Skip to content

Commit a2d107d

Browse files
committed
feat(index): add write function for Serialize
Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
1 parent d9ac223 commit a2d107d

File tree

12 files changed

+105
-0
lines changed

12 files changed

+105
-0
lines changed

include/vsag/index.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ enum class IndexType { HNSW, DISKANN, HGRAPH, IVF, PYRAMID, BRUTEFORCE, SPARSE,
5151
#define DATA_FLAG_ATTRIBUTE 0x20
5252
#define DATA_FLAG_ID 0x40
5353

54+
using OffsetType = uint64_t;
55+
using SizeType = uint64_t;
56+
using WriteFuncType = std::function<void(OffsetType, SizeType, const void*)>;
57+
5458
class Index {
5559
public:
5660
// [basic methods]
@@ -637,6 +641,16 @@ class Index {
637641
[[nodiscard]] virtual tl::expected<BinarySet, Error>
638642
Serialize() const = 0;
639643

644+
/**
645+
* @brief Serialize index by write function
646+
*
647+
* @param write_func is a function to write serialized index
648+
*/
649+
[[nodiscard]] virtual tl::expected<void, Error>
650+
Serialize(WriteFuncType write_func) const {
651+
throw std::runtime_error("Index doesn't support Serialize with write function");
652+
}
653+
640654
/**
641655
* @brief Deserialize index from a set of byte array. Causing exception if this index is not empty
642656
*

include/vsag/index_features.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ enum IndexFeature {
3535
SUPPORT_METRIC_TYPE_COSINE, /**< Supports cosine metric type */
3636
SUPPORT_SERIALIZE_FILE, /**< Supports serialization to a file */
3737
SUPPORT_SERIALIZE_BINARY_SET, /**< Supports serialization to a binary set */
38+
SUPPORT_SERIALIZE_WRITE_FUNC, /**< Supports serialization to a write function */
3839
SUPPORT_DESERIALIZE_FILE, /**< Supports deserialization from a file */
3940
SUPPORT_DESERIALIZE_BINARY_SET, /**< Supports deserialization from a binary set */
4041
SUPPORT_DESERIALIZE_READER_SET, /**< Supports deserialization from a reader set */

src/algorithm/brute_force.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ BruteForce::InitFeatures() {
422422
IndexFeature::SUPPORT_DESERIALIZE_READER_SET,
423423
IndexFeature::SUPPORT_SERIALIZE_BINARY_SET,
424424
IndexFeature::SUPPORT_SERIALIZE_FILE,
425+
IndexFeature::SUPPORT_SERIALIZE_WRITE_FUNC,
425426
});
426427

427428
// others

src/algorithm/hgraph.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,7 @@ HGraph::InitFeatures() {
11161116
IndexFeature::SUPPORT_DESERIALIZE_READER_SET,
11171117
IndexFeature::SUPPORT_SERIALIZE_BINARY_SET,
11181118
IndexFeature::SUPPORT_SERIALIZE_FILE,
1119+
IndexFeature::SUPPORT_SERIALIZE_WRITE_FUNC,
11191120
});
11201121
// other
11211122
this->index_feature_list_->SetFeatures({

src/algorithm/inner_index_interface.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,16 @@ InnerIndexInterface::Serialize() const {
143143
return bs;
144144
}
145145

146+
void
147+
InnerIndexInterface::Serialize(const WriteFuncType& write_func) const {
148+
std::string time_record_name = this->GetName() + " Serialize";
149+
SlowTaskTimer t(time_record_name);
150+
151+
uint64_t num_bytes = this->CalSerializeSize();
152+
WriteFuncStreamWriter writer(write_func, 0);
153+
this->Serialize(writer);
154+
}
155+
146156
void
147157
InnerIndexInterface::Deserialize(const BinarySet& binary_set) {
148158
std::string time_record_name = this->GetName() + " Deserialize";

src/algorithm/inner_index_interface.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,9 @@ class InnerIndexInterface {
329329
virtual void
330330
Serialize(std::ostream& out_stream) const;
331331

332+
virtual void
333+
Serialize(const WriteFuncType& write_func) const;
334+
332335
virtual void
333336
Serialize(StreamWriter& writer) const = 0;
334337

src/algorithm/ivf.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ IVF::InitFeatures() {
296296
IndexFeature::SUPPORT_DESERIALIZE_READER_SET,
297297
IndexFeature::SUPPORT_SERIALIZE_BINARY_SET,
298298
IndexFeature::SUPPORT_SERIALIZE_FILE,
299+
IndexFeature::SUPPORT_SERIALIZE_WRITE_FUNC,
299300
});
300301

301302
auto name = this->bucket_->GetQuantizerName();

src/index/index_impl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,11 @@ class IndexImpl : public Index {
376376
SAFE_CALL(return this->inner_index_->Serialize());
377377
}
378378

379+
[[nodiscard]] tl::expected<void, Error>
380+
Serialize(WriteFuncType write_func) const override {
381+
SAFE_CALL(this->inner_index_->Serialize(write_func));
382+
}
383+
379384
tl::expected<void, Error>
380385
Serialize(std::ostream& out_stream) override {
381386
SAFE_CALL(this->inner_index_->Serialize(out_stream));

tests/test_hgraph.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,8 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HGraphTestIndex,
678678
TestIndex::TestSerializeBinarySet(index, index2, dataset, search_param, true);
679679
index2 = TestIndex::TestFactory(name, param, true);
680680
TestIndex::TestSerializeReaderSet(index, index2, dataset, search_param, name, true);
681+
index2 = TestIndex::TestFactory(name, param, true);
682+
TestIndex::TestSerializeWriteFunc(index, index2, dataset, search_param, true);
681683
vsag::Options::Instance().set_block_size_limit(origin_size);
682684
}
683685
static void
@@ -1413,6 +1415,8 @@ TestHGraphSerialize(const fixtures::HGraphTestIndexPtr& test_index,
14131415
index2 = TestIndex::TestFactory(test_index->name, param, true);
14141416
TestIndex::TestSerializeReaderSet(
14151417
index, index2, dataset, search_param, test_index->name, true);
1418+
index2 = TestIndex::TestFactory(test_index->name, param, true);
1419+
TestIndex::TestSerializeWriteFunc(index, index2, dataset, search_param, true);
14161420
vsag::Options::Instance().set_block_size_limit(origin_size);
14171421
}
14181422
}

tests/test_index.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,58 @@ TestIndex::TestSerializeReaderSet(const IndexPtr& index_from,
10861086
}
10871087
}
10881088

1089+
void
1090+
TestIndex::TestSerializeWriteFunc(const IndexPtr& index_from,
1091+
const IndexPtr& index_to,
1092+
const TestDatasetPtr& dataset,
1093+
const std::string& search_param,
1094+
bool expected_success) {
1095+
if (not index_from->CheckFeature(vsag::SUPPORT_SERIALIZE_WRITE_FUNC)) {
1096+
return;
1097+
}
1098+
auto dir = fixtures::TempDir("serialize");
1099+
auto path = dir.GenerateRandomFile();
1100+
std::ofstream outfile(path, std::ios::out | std::ios::binary);
1101+
vsag::WriteFuncType write_func =
1102+
[&outfile](vsag::OffsetType offset, vsag::SizeType size, const void* data) -> void {
1103+
outfile.seekp(offset);
1104+
outfile.write(reinterpret_cast<const char*>(data), size);
1105+
};
1106+
auto serialize_index = index_from->Serialize(write_func);
1107+
REQUIRE(serialize_index.has_value() == expected_success);
1108+
outfile.close();
1109+
1110+
std::ifstream infile(path, std::ios::in | std::ios::binary);
1111+
auto deserialize_index = index_to->Deserialize(infile);
1112+
REQUIRE(deserialize_index.has_value() == expected_success);
1113+
infile.close();
1114+
if (index_to->GetNumElements() == 0) {
1115+
return;
1116+
}
1117+
1118+
const auto& queries = dataset->query_;
1119+
auto query_count = queries->GetNumElements();
1120+
auto dim = queries->GetDim();
1121+
auto topk = 10;
1122+
for (auto i = 0; i < query_count; ++i) {
1123+
auto query = vsag::Dataset::Make();
1124+
query->NumElements(1)
1125+
->Dim(dim)
1126+
->Paths(queries->GetPaths() + i)
1127+
->SparseVectors(queries->GetSparseVectors() + i)
1128+
->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
1129+
->Owner(false);
1130+
auto res_from = index_from->KnnSearch(query, topk, search_param);
1131+
auto res_to = index_to->KnnSearch(query, topk, search_param);
1132+
REQUIRE(res_from.has_value());
1133+
REQUIRE(res_to.has_value());
1134+
REQUIRE(res_from.value()->GetDim() == res_to.value()->GetDim());
1135+
for (auto j = 0; j < topk; ++j) {
1136+
REQUIRE(res_to.value()->GetIds()[j] == res_from.value()->GetIds()[j]);
1137+
}
1138+
}
1139+
}
1140+
10891141
void
10901142
TestIndex::TestConcurrentAddSearch(const TestIndex::IndexPtr& index,
10911143
const TestDatasetPtr& dataset,

0 commit comments

Comments
 (0)