Skip to content

Commit 9ea8e48

Browse files
authored
feat: update lumina lib for diskann (alibaba#51)
1 parent 43d9d02 commit 9ea8e48

39 files changed

+713
-416
lines changed

src/paimon/global_index/lumina/lumina_api_test.cpp

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class LuminaInterfaceTest : public ::testing::Test {
3030
void TearDown() override {}
3131

3232
void WriteAndFlush(const std::string& index_path,
33-
const std::vector<::lumina::core::VectorId>& row_ids) const {
33+
const std::vector<::lumina::core::vector_id_t>& row_ids) const {
3434
auto fs = std::make_shared<LocalFileSystem>();
3535
std::shared_ptr<MemoryPool> paimon_pool = GetMemoryPool();
3636
auto pool = std::make_shared<LuminaMemoryPool>(paimon_pool);
@@ -43,8 +43,8 @@ class LuminaInterfaceTest : public ::testing::Test {
4343
.Set(::lumina::core::kEncodingType, ::lumina::core::kEncodingRawf32);
4444
auto builder_result =
4545
::lumina::api::LuminaBuilder::Create(builder_options, memory_resource);
46-
ASSERT_TRUE(builder_result.status.IsOk()) << builder_result.status.Message();
47-
auto writer = std::move(builder_result.value);
46+
ASSERT_TRUE(builder_result.IsOk()) << builder_result.GetStatus().Message();
47+
auto writer = std::move(builder_result).TakeValue();
4848
// pretrain
4949
ASSERT_TRUE(writer.Pretrain(/*data=*/nullptr, /*n=*/0).IsOk());
5050
// insert data
@@ -65,9 +65,9 @@ class LuminaInterfaceTest : public ::testing::Test {
6565
}
6666

6767
void Search(const std::string& index_path, int32_t topk,
68-
const std::vector<::lumina::core::VectorId>& expected_row_ids,
68+
const std::vector<::lumina::core::vector_id_t>& expected_row_ids,
6969
const std::vector<float>& expected_distances,
70-
const std::function<bool(::lumina::core::VectorId id)>& filter = nullptr) const {
70+
const std::function<bool(::lumina::core::vector_id_t id)>& filter = nullptr) const {
7171
ASSERT_EQ(expected_row_ids.size(), expected_distances.size());
7272
auto fs = std::make_shared<LocalFileSystem>();
7373
std::shared_ptr<MemoryPool> paimon_pool = GetMemoryPool();
@@ -77,12 +77,11 @@ class LuminaInterfaceTest : public ::testing::Test {
7777
// create reader
7878
::lumina::api::SearcherOptions searcher_options;
7979
searcher_options.Set(::lumina::core::kIndexType, ::lumina::core::kIndexTypeBruteforce)
80-
.Set(::lumina::core::kDimension, 4)
81-
.Set(::lumina::core::kSearchThreadCount, 10);
80+
.Set(::lumina::core::kDimension, 4);
8281
auto reader_result =
8382
::lumina::api::LuminaSearcher::Create(searcher_options, memory_resource);
84-
ASSERT_TRUE(reader_result.status.IsOk());
85-
auto reader = std::move(reader_result.value);
83+
ASSERT_TRUE(reader_result.IsOk());
84+
auto reader = std::move(reader_result).TakeValue();
8685
ASSERT_OK_AND_ASSIGN(std::shared_ptr<InputStream> in, fs->Open(index_path));
8786
auto file_reader = std::make_unique<LuminaFileReader>(in);
8887
ASSERT_TRUE(reader.Open(std::move(file_reader), ::lumina::api::IOOptions()).IsOk());
@@ -102,18 +101,20 @@ class LuminaInterfaceTest : public ::testing::Test {
102101
if (parallel_number > 0) {
103102
search_options.Set(::lumina::core::kSearchParallelNumber, parallel_number);
104103
}
105-
::lumina::core::Result<::lumina::api::LuminaSearcher::SearchResult> search_result;
104+
106105
if (!filter) {
107-
search_result = reader.Search(query, search_options, *pool);
106+
auto search_result = reader.Search(query, search_options, *pool);
107+
ASSERT_TRUE(search_result.IsOk()) << search_result.GetStatus().Message();
108+
CheckResult(search_result.Value().topk, expected_row_ids, expected_distances);
108109
} else {
109110
search_options.Set(::lumina::core::kSearchThreadSafeFilter, true);
110111
::lumina::extensions::SearchWithFilterExtension reader_with_filter;
111112
ASSERT_TRUE(reader.Attach(reader_with_filter).IsOk());
112-
search_result =
113+
auto search_result =
113114
reader_with_filter.SearchWithFilter(query, filter, search_options, *pool);
115+
ASSERT_TRUE(search_result.IsOk()) << search_result.GetStatus().Message();
116+
CheckResult(search_result.Value().topk, expected_row_ids, expected_distances);
114117
}
115-
ASSERT_TRUE(search_result.status.IsOk()) << search_result.status.Message();
116-
CheckResult(search_result.value.topk, expected_row_ids, expected_distances);
117118

118119
// TODO(xinyu.lxy): check memory paimon_pool, current memory use = query mem +
119120
// reader mem
@@ -133,7 +134,7 @@ class LuminaInterfaceTest : public ::testing::Test {
133134
}
134135

135136
void CheckResult(const std::vector<::lumina::api::LuminaSearcher::SearchHit>& search_result,
136-
const std::vector<::lumina::core::VectorId>& expected_row_ids,
137+
const std::vector<::lumina::core::vector_id_t>& expected_row_ids,
137138
const std::vector<float>& expected_distances) const {
138139
ASSERT_EQ(search_result.size(), expected_row_ids.size());
139140
for (size_t i = 0; i < search_result.size(); i++) {
@@ -155,11 +156,11 @@ TEST_F(LuminaInterfaceTest, TestSimple) {
155156
std::string index_path = dir->Str() + "/lumina_test.index";
156157

157158
// write index
158-
std::vector<::lumina::core::VectorId> row_ids = {0l, 1l, 2l, 3l};
159+
std::vector<::lumina::core::vector_id_t> row_ids = {0l, 1l, 2l, 3l};
159160
WriteAndFlush(index_path, row_ids);
160161

161162
// read index
162-
std::vector<::lumina::core::VectorId> expected_row_ids = {3l, 1l, 2l, 0l};
163+
std::vector<::lumina::core::vector_id_t> expected_row_ids = {3l, 1l, 2l, 0l};
163164
std::vector<float> expected_distances = {0.01f, 2.01f, 2.21f, 4.21f};
164165
Search(index_path, /*topk=*/4, expected_row_ids, expected_distances);
165166
}
@@ -169,11 +170,11 @@ TEST_F(LuminaInterfaceTest, TestWithDocIdGap) {
169170
std::string index_path = dir->Str() + "/lumina_test.index";
170171

171172
// write index
172-
std::vector<::lumina::core::VectorId> row_ids = {0l, 2l, 4l, 6l};
173+
std::vector<::lumina::core::vector_id_t> row_ids = {0l, 2l, 4l, 6l};
173174
WriteAndFlush(index_path, row_ids);
174175

175176
// read index
176-
std::vector<::lumina::core::VectorId> expected_row_ids = {6l, 2l, 4l, 0l};
177+
std::vector<::lumina::core::vector_id_t> expected_row_ids = {6l, 2l, 4l, 0l};
177178
std::vector<float> expected_distances = {0.01f, 2.01f, 2.21f, 4.21f};
178179
Search(index_path, /*topk=*/4, expected_row_ids, expected_distances);
179180
}
@@ -183,11 +184,11 @@ TEST_F(LuminaInterfaceTest, TestWithSmallTopk) {
183184
std::string index_path = dir->Str() + "/lumina_test.index";
184185

185186
// write index
186-
std::vector<::lumina::core::VectorId> row_ids = {0l, 1l, 2l, 3l};
187+
std::vector<::lumina::core::vector_id_t> row_ids = {0l, 1l, 2l, 3l};
187188
WriteAndFlush(index_path, row_ids);
188189

189190
// read index
190-
std::vector<::lumina::core::VectorId> expected_row_ids = {3l, 1l, 2l};
191+
std::vector<::lumina::core::vector_id_t> expected_row_ids = {3l, 1l, 2l};
191192
std::vector<float> expected_distances = {0.01f, 2.01f, 2.21f};
192193
Search(index_path, /*topk=*/3, expected_row_ids, expected_distances);
193194
}
@@ -197,20 +198,20 @@ TEST_F(LuminaInterfaceTest, TestWithFilter) {
197198
std::string index_path = dir->Str() + "/lumina_test.index";
198199

199200
// write index
200-
std::vector<::lumina::core::VectorId> row_ids = {0l, 1l, 2l, 3l};
201+
std::vector<::lumina::core::vector_id_t> row_ids = {0l, 1l, 2l, 3l};
201202
WriteAndFlush(index_path, row_ids);
202203

203204
// read index
204205
{
205-
std::vector<::lumina::core::VectorId> expected_row_ids = {1l, 2l};
206+
std::vector<::lumina::core::vector_id_t> expected_row_ids = {1l, 2l};
206207
std::vector<float> expected_distances = {2.01f, 2.21f};
207-
auto filter = [](::lumina::core::VectorId id) -> bool { return id < 3; };
208+
auto filter = [](::lumina::core::vector_id_t id) -> bool { return id < 3; };
208209
Search(index_path, /*topk=*/2, expected_row_ids, expected_distances, filter);
209210
}
210211
{
211-
std::vector<::lumina::core::VectorId> expected_row_ids = {1l, 2l, 0l};
212+
std::vector<::lumina::core::vector_id_t> expected_row_ids = {1l, 2l, 0l};
212213
std::vector<float> expected_distances = {2.01f, 2.21f, 4.21f};
213-
auto filter = [](::lumina::core::VectorId id) -> bool { return id < 3; };
214+
auto filter = [](::lumina::core::vector_id_t id) -> bool { return id < 3; };
214215
Search(index_path, /*topk=*/4, expected_row_ids, expected_distances, filter);
215216
}
216217
}

src/paimon/global_index/lumina/lumina_file_reader.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class LuminaFileReader : public ::lumina::io::FileReader {
6262
return PaimonToLuminaStatus(read_result.status());
6363
}
6464
if (static_cast<uint64_t>(read_result.value()) != current_read_size) {
65-
return ::lumina::core::Status::Error(
65+
return ::lumina::core::Status(
6666
::lumina::core::ErrorCode::IoError,
6767
fmt::format("expect read len {} mismatch actual read len {}", current_read_size,
6868
read_result.value()));

src/paimon/global_index/lumina/lumina_file_writer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class LuminaFileWriter : public ::lumina::io::FileWriter {
4848
return PaimonToLuminaStatus(write_result.status());
4949
}
5050
if (static_cast<uint64_t>(write_result.value()) != current_write_size) {
51-
return ::lumina::core::Status::Error(
51+
return ::lumina::core::Status(
5252
::lumina::core::ErrorCode::IoError,
5353
fmt::format("expect write len {} mismatch actual write len {}",
5454
current_write_size, write_result.value()));

src/paimon/global_index/lumina/lumina_global_index.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ class LuminaDataset : public ::lumina::api::Dataset {
164164

165165
::lumina::core::Result<uint64_t> GetNextBatch(
166166
std::vector<float>& vector_buffer,
167-
std::vector<::lumina::core::VectorId>& id_buffer) override {
167+
std::vector<::lumina::core::vector_id_t>& id_buffer) override {
168168
if (cursor_ >= array_vec_.size()) {
169169
return ::lumina::core::Result<uint64_t>::Ok(0);
170170
}
@@ -189,7 +189,7 @@ class LuminaDataset : public ::lumina::api::Dataset {
189189
uint32_t dimension_;
190190
std::vector<std::shared_ptr<arrow::FloatArray>> array_vec_;
191191
size_t cursor_ = 0;
192-
::lumina::core::VectorId id_ = 0;
192+
::lumina::core::vector_id_t id_ = 0;
193193
};
194194

195195
LuminaIndexWriter::LuminaIndexWriter(const std::string& field_name,
@@ -208,7 +208,6 @@ LuminaIndexWriter::LuminaIndexWriter(const std::string& field_name,
208208
io_options_(std::move(io_options)) {}
209209

210210
Status LuminaIndexWriter::AddBatch(::ArrowArray* arrow_array) {
211-
// TODO(xinyu.lxy): may use async thread to read data and build index
212211
PAIMON_ASSIGN_OR_RAISE_FROM_ARROW(std::shared_ptr<arrow::Array> array,
213212
arrow::ImportArray(arrow_array, arrow_type_));
214213
if (array->null_count() != 0) {
@@ -297,7 +296,7 @@ Result<std::shared_ptr<VectorSearchGlobalIndexResult>> LuminaIndexReader::VisitV
297296
} else {
298297
search_options.Set(::lumina::core::kSearchThreadSafeFilter, true);
299298
auto lumina_filter = [filter = vector_search->pre_filter](
300-
::lumina::core::VectorId id) -> bool { return filter(id); };
299+
::lumina::core::vector_id_t id) -> bool { return filter(id); };
301300
PAIMON_ASSIGN_OR_RAISE_FROM_LUMINA(
302301
search_result, searcher_with_filter_->SearchWithFilter(lumina_query, lumina_filter,
303302
search_options, *pool_));

src/paimon/global_index/lumina/lumina_global_index_test.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,11 @@ class LuminaGlobalIndexTest : public ::testing::Test {
159159
private:
160160
std::shared_ptr<MemoryPool> pool_ = GetDefaultPool();
161161
std::shared_ptr<FileSystem> fs_ = std::make_shared<LocalFileSystem>();
162-
std::map<std::string, std::string> options_ = {{"lumina.dimension", "4"},
163-
{"lumina.indextype", "bruteforce"},
162+
std::map<std::string, std::string> options_ = {{"lumina.index.dimension", "4"},
163+
{"lumina.index.type", "bruteforce"},
164164
{"lumina.distance.metric", "l2"},
165-
{"lumina.encoding.type", "encoding.rawf32"},
166-
{"lumina.search.threadcount", "10"}};
165+
{"lumina.encoding.type", "rawf32"},
166+
{"lumina.search.thread_count", "10"}};
167167
std::shared_ptr<arrow::DataType> data_type_ =
168168
arrow::struct_({arrow::field("f0", arrow::list(arrow::float32()))});
169169
std::shared_ptr<arrow::Array> array_ = arrow::ipc::internal::json::ArrayFromJSON(data_type_,
@@ -250,15 +250,15 @@ TEST_F(LuminaGlobalIndexTest, TestInvalidInputs) {
250250
{
251251
// invalid options
252252
std::map<std::string, std::string> options = options_;
253-
options["lumina.dimension"] = "xxx";
253+
options["lumina.index.dimension"] = "xxx";
254254
ASSERT_NOK_WITH_MSG(
255255
WriteGlobalIndex(index_root, data_type_, options, /*array=*/nullptr, Range(0, 0)),
256-
"convert key lumina.dimension, value xxx to unsigned int failed");
256+
"convert key lumina.index.dimension, value xxx to unsigned int failed");
257257
GlobalIndexIOMeta fake_meta("fake_file_name", /*file_size=*/10,
258258
/*range_end=*/5,
259259
/*metadata=*/nullptr);
260260
ASSERT_NOK_WITH_MSG(CreateGlobalIndexReader(index_root, data_type_, options, fake_meta),
261-
"convert key lumina.dimension, value xxx to unsigned int failed");
261+
"convert key lumina.index.dimension, value xxx to unsigned int failed");
262262
}
263263
{
264264
// invalid inputs in write
@@ -367,7 +367,7 @@ TEST_F(LuminaGlobalIndexTest, TestInvalidInputs) {
367367
}
368368
{
369369
std::map<std::string, std::string> options = options_;
370-
options["lumina.dimension"] = "5";
370+
options["lumina.index.dimension"] = "5";
371371
ASSERT_NOK_WITH_MSG(CreateGlobalIndexReader(index_root, data_type_, options, meta),
372372
"lumina index dimension 4 mismatch dimension 5 in options");
373373
}

src/paimon/global_index/lumina/lumina_utils.h

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ namespace paimon::lumina {
3030
} \
3131
} while (false)
3232

33-
#define PAIMON_ASSIGN_OR_RAISE_IMPL_FROM_LUMINA(result_name, lhs, rexpr) \
34-
auto&& result_name = (rexpr); \
35-
PAIMON_RETURN_IF_(!(result_name).IsOk(), LuminaToPaimonStatus((result_name).status), \
36-
PAIMON_STRINGIFY(rexpr)); \
37-
lhs = std::move(result_name.value);
33+
#define PAIMON_ASSIGN_OR_RAISE_IMPL_FROM_LUMINA(result_name, lhs, rexpr) \
34+
auto&& result_name = (rexpr); \
35+
PAIMON_RETURN_IF_(!(result_name).IsOk(), LuminaToPaimonStatus((result_name).GetStatus()), \
36+
PAIMON_STRINGIFY(rexpr)); \
37+
lhs = std::move(result_name).TakeValue();
3838

3939
#define PAIMON_ASSIGN_OR_RAISE_FROM_LUMINA(lhs, rexpr) \
4040
PAIMON_ASSIGN_OR_RAISE_IMPL_FROM_LUMINA( \
@@ -45,23 +45,20 @@ inline ::lumina::core::Status PaimonToLuminaStatus(const Status& status) {
4545
case StatusCode::OK:
4646
return ::lumina::core::Status::Ok();
4747
case StatusCode::OutOfMemory:
48-
return ::lumina::core::Status::Error(::lumina::core::ErrorCode::OutOfMemory,
49-
status.message());
48+
return ::lumina::core::Status(::lumina::core::ErrorCode::OutOfMemory, status.message());
5049
case StatusCode::IOError:
51-
return ::lumina::core::Status::Error(::lumina::core::ErrorCode::IoError,
52-
status.message());
50+
return ::lumina::core::Status(::lumina::core::ErrorCode::IoError, status.message());
5351
case StatusCode::NotImplemented:
54-
return ::lumina::core::Status::Error(::lumina::core::ErrorCode::NotSupported,
55-
status.message());
52+
return ::lumina::core::Status(::lumina::core::ErrorCode::NotSupported,
53+
status.message());
5654
case StatusCode::NotExist:
57-
return ::lumina::core::Status::Error(::lumina::core::ErrorCode::NotFound,
58-
status.message());
55+
return ::lumina::core::Status(::lumina::core::ErrorCode::NotFound, status.message());
5956
case StatusCode::Exist:
60-
return ::lumina::core::Status::Error(::lumina::core::ErrorCode::AlreadyExists,
61-
status.message());
57+
return ::lumina::core::Status(::lumina::core::ErrorCode::AlreadyExists,
58+
status.message());
6259
default:
63-
return ::lumina::core::Status::Error(::lumina::core::ErrorCode::InvalidArgument,
64-
status.message());
60+
return ::lumina::core::Status(::lumina::core::ErrorCode::InvalidArgument,
61+
status.message());
6562
}
6663
}
6764

0 commit comments

Comments
 (0)