Skip to content

Commit a6e916d

Browse files
authored
fix(index): add check for empty input in index (#1159) (#1161)
Signed-off-by: jinjiabao.jjb <jinjiabao.jjb@antgroup.com>
1 parent 00e6c1a commit a6e916d

File tree

2 files changed

+130
-0
lines changed

2 files changed

+130
-0
lines changed

src/index/index_impl.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,30 @@ class IndexImpl : public Index {
4848
this->inner_index_.reset();
4949
}
5050

51+
#define CHECK_AND_RETURN_EMPTY_DATASET \
52+
if (GetNumElements() == 0) { \
53+
return DatasetImpl::MakeEmptyDataset(); \
54+
}
55+
56+
#define CHECK_IMMUTABLE_INDEX(operation_str) \
57+
if (this->inner_index_->immutable_) { \
58+
return tl::unexpected(Error(ErrorType::UNSUPPORTED_INDEX_OPERATION, \
59+
"immutable index no support " operation_str)); \
60+
}
61+
62+
#define CHECK_NONEMPTY_DATASET(dataset) \
63+
if ((dataset)->GetNumElements() == 0) { \
64+
LOG_ERROR_AND_RETURNS(ErrorType::INVALID_ARGUMENT, "input dataset is empty"); \
65+
}
66+
5167
public:
5268
tl::expected<std::vector<int64_t>, Error>
5369
Build(const DatasetPtr& base) override {
5470
if (this->inner_index_->immutable_) {
5571
return tl::unexpected(
5672
Error(ErrorType::UNSUPPORTED_INDEX_OPERATION, "immutable index no support build"));
5773
}
74+
CHECK_NONEMPTY_DATASET(base);
5875
SAFE_CALL(return this->inner_index_->Build(base));
5976
}
6077

@@ -69,6 +86,7 @@ class IndexImpl : public Index {
6986
return tl::unexpected(
7087
Error(ErrorType::UNSUPPORTED_INDEX_OPERATION, "immutable index no support train"));
7188
}
89+
CHECK_NONEMPTY_DATASET(data);
7290
SAFE_CALL(this->inner_index_->Train(data));
7391
}
7492

@@ -78,6 +96,7 @@ class IndexImpl : public Index {
7896
return tl::unexpected(Error(ErrorType::UNSUPPORTED_INDEX_OPERATION,
7997
"immutable index no support continue build"));
8098
}
99+
CHECK_NONEMPTY_DATASET(base);
81100
SAFE_CALL(return this->inner_index_->ContinueBuild(base, binary_set));
82101
}
83102

@@ -87,6 +106,7 @@ class IndexImpl : public Index {
87106
return tl::unexpected(
88107
Error(ErrorType::UNSUPPORTED_INDEX_OPERATION, "immutable index no support add"));
89108
}
109+
CHECK_NONEMPTY_DATASET(base);
90110
SAFE_CALL(return this->inner_index_->Add(base));
91111
}
92112

@@ -114,6 +134,7 @@ class IndexImpl : public Index {
114134
return tl::unexpected(Error(ErrorType::UNSUPPORTED_INDEX_OPERATION,
115135
"immutable index no support update vector"));
116136
}
137+
CHECK_NONEMPTY_DATASET(new_base);
117138
SAFE_CALL(return this->inner_index_->UpdateVector(id, new_base, force_update));
118139
}
119140

@@ -144,6 +165,7 @@ class IndexImpl : public Index {
144165
return tl::unexpected(Error(ErrorType::UNSUPPORTED_INDEX_OPERATION,
145166
"immutable index no support update extra info"));
146167
}
168+
CHECK_NONEMPTY_DATASET(new_base);
147169
SAFE_CALL(return this->inner_index_->UpdateExtraInfo(new_base));
148170
}
149171

@@ -160,6 +182,7 @@ class IndexImpl : public Index {
160182
int64_t k,
161183
const std::string& parameters,
162184
BitsetPtr invalid = nullptr) const override {
185+
CHECK_NONEMPTY_DATASET(query);
163186
if (GetNumElements() == 0) {
164187
return DatasetImpl::MakeEmptyDataset();
165188
}
@@ -171,6 +194,7 @@ class IndexImpl : public Index {
171194
int64_t k,
172195
const std::string& parameters,
173196
const std::function<bool(int64_t)>& filter) const override {
197+
CHECK_NONEMPTY_DATASET(query);
174198
if (GetNumElements() == 0) {
175199
return DatasetImpl::MakeEmptyDataset();
176200
}
@@ -182,6 +206,7 @@ class IndexImpl : public Index {
182206
int64_t k,
183207
const std::string& parameters,
184208
const FilterPtr& filter) const override {
209+
CHECK_NONEMPTY_DATASET(query);
185210
if (GetNumElements() == 0) {
186211
return DatasetImpl::MakeEmptyDataset();
187212
}
@@ -190,6 +215,7 @@ class IndexImpl : public Index {
190215

191216
tl::expected<DatasetPtr, Error>
192217
KnnSearch(const DatasetPtr& query, int64_t k, SearchParam& search_param) const override {
218+
CHECK_NONEMPTY_DATASET(query);
193219
if (GetNumElements() == 0) {
194220
return DatasetImpl::MakeEmptyDataset();
195221
}
@@ -214,6 +240,7 @@ class IndexImpl : public Index {
214240
const FilterPtr& filter,
215241
IteratorContext*& iter_ctx,
216242
bool is_last_filter) const override {
243+
CHECK_NONEMPTY_DATASET(query);
217244
if (GetNumElements() == 0) {
218245
return DatasetImpl::MakeEmptyDataset();
219246
}
@@ -226,6 +253,7 @@ class IndexImpl : public Index {
226253
float radius,
227254
const std::string& parameters,
228255
int64_t limited_size = -1) const override {
256+
CHECK_NONEMPTY_DATASET(query);
229257
if (GetNumElements() == 0) {
230258
return DatasetImpl::MakeEmptyDataset();
231259
}
@@ -238,6 +266,7 @@ class IndexImpl : public Index {
238266
const std::string& parameters,
239267
BitsetPtr invalid,
240268
int64_t limited_size = -1) const override {
269+
CHECK_NONEMPTY_DATASET(query);
241270
if (GetNumElements() == 0) {
242271
return DatasetImpl::MakeEmptyDataset();
243272
}
@@ -251,6 +280,7 @@ class IndexImpl : public Index {
251280
const std::string& parameters,
252281
const std::function<bool(int64_t)>& filter,
253282
int64_t limited_size = -1) const override {
283+
CHECK_NONEMPTY_DATASET(query);
254284
if (GetNumElements() == 0) {
255285
return DatasetImpl::MakeEmptyDataset();
256286
}
@@ -264,6 +294,7 @@ class IndexImpl : public Index {
264294
const std::string& parameters,
265295
const FilterPtr& filter,
266296
int64_t limited_size = -1) const override {
297+
CHECK_NONEMPTY_DATASET(query);
267298
if (GetNumElements() == 0) {
268299
return DatasetImpl::MakeEmptyDataset();
269300
}
@@ -291,6 +322,7 @@ class IndexImpl : public Index {
291322
return tl::unexpected(Error(ErrorType::UNSUPPORTED_INDEX_OPERATION,
292323
"immutable index no support feedback"));
293324
}
325+
CHECK_NONEMPTY_DATASET(query);
294326
SAFE_CALL(return this->inner_index_->Feedback(query, k, parameters, global_optimum_tag_id));
295327
}
296328

src/index/index_impl_test.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,101 @@ TEST_CASE("immutable index test", "[ut][index_impl]") {
9191
REQUIRE_FALSE(result_merge.has_value());
9292
REQUIRE(result_merge.error().type == vsag::ErrorType::UNSUPPORTED_INDEX_OPERATION);
9393
}
94+
95+
TEST_CASE("index empty input test", "[ut][index_impl]") {
96+
vsag::IndexCommonParam common_param;
97+
common_param.dim_ = 128;
98+
common_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT;
99+
common_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR;
100+
common_param.allocator_ = vsag::Engine::CreateDefaultAllocator();
101+
auto build_parameter_json = R"(
102+
{
103+
"base_quantization_type": "fp32",
104+
"max_degree": 16,
105+
"ef_construction": 100
106+
}
107+
)";
108+
109+
vsag::JsonType hgraph_json;
110+
hgraph_json = vsag::JsonType::parse(build_parameter_json);
111+
auto index = std::make_shared<vsag::IndexImpl<vsag::HGraph>>(hgraph_json, common_param);
112+
113+
vsag::DatasetPtr dataset = vsag::Dataset::Make();
114+
vsag::BinarySet binary_set;
115+
116+
auto result_build = index->Build(dataset);
117+
REQUIRE_FALSE(result_build.has_value());
118+
REQUIRE(result_build.error().type == vsag::ErrorType::INVALID_ARGUMENT);
119+
120+
auto result_train = index->Train(dataset);
121+
REQUIRE_FALSE(result_train.has_value());
122+
REQUIRE(result_train.error().type == vsag::ErrorType::INVALID_ARGUMENT);
123+
124+
auto result_continue_build = index->ContinueBuild(dataset, binary_set);
125+
REQUIRE_FALSE(result_continue_build.has_value());
126+
REQUIRE(result_continue_build.error().type == vsag::ErrorType::INVALID_ARGUMENT);
127+
128+
auto result_add = index->Add(dataset);
129+
REQUIRE_FALSE(result_add.has_value());
130+
REQUIRE(result_add.error().type == vsag::ErrorType::INVALID_ARGUMENT);
131+
132+
auto result_update_vector = index->UpdateVector(0, dataset);
133+
REQUIRE_FALSE(result_update_vector.has_value());
134+
REQUIRE(result_update_vector.error().type == vsag::ErrorType::INVALID_ARGUMENT);
135+
136+
auto result_update_extrainfo = index->UpdateExtraInfo(dataset);
137+
REQUIRE_FALSE(result_update_extrainfo.has_value());
138+
REQUIRE(result_update_extrainfo.error().type == vsag::ErrorType::INVALID_ARGUMENT);
139+
140+
auto result_feedback = index->Feedback(dataset, 0, "");
141+
REQUIRE_FALSE(result_feedback.has_value());
142+
REQUIRE(result_feedback.error().type == vsag::ErrorType::INVALID_ARGUMENT);
143+
144+
// test search empty dataset
145+
int64_t k = 0;
146+
float radius = 0.1;
147+
int64_t limited_size = 0;
148+
auto query = vsag::Dataset::Make();
149+
std::string parameters = "";
150+
auto invalid = vsag::Bitset::Make();
151+
auto filter = [](int64_t) -> bool { return true; };
152+
vsag::FilterPtr filter_ptr = nullptr;
153+
vsag::IteratorContext* iter_ctx = nullptr;
154+
vsag::SearchParam search_param(true, parameters, filter_ptr, nullptr);
155+
156+
auto search_result = index->KnnSearch(query, k, parameters, invalid);
157+
REQUIRE_FALSE(search_result.has_value());
158+
REQUIRE(search_result.error().type == vsag::ErrorType::INVALID_ARGUMENT);
159+
160+
search_result = index->KnnSearch(query, k, parameters, filter);
161+
REQUIRE_FALSE(search_result.has_value());
162+
REQUIRE(search_result.error().type == vsag::ErrorType::INVALID_ARGUMENT);
163+
164+
search_result = index->KnnSearch(query, k, parameters, filter_ptr);
165+
REQUIRE_FALSE(search_result.has_value());
166+
REQUIRE(search_result.error().type == vsag::ErrorType::INVALID_ARGUMENT);
167+
168+
search_result = index->KnnSearch(query, k, search_param);
169+
REQUIRE_FALSE(search_result.has_value());
170+
REQUIRE(search_result.error().type == vsag::ErrorType::INVALID_ARGUMENT);
171+
172+
search_result = index->KnnSearch(query, k, parameters, filter_ptr, iter_ctx, true);
173+
REQUIRE_FALSE(search_result.has_value());
174+
REQUIRE(search_result.error().type == vsag::ErrorType::INVALID_ARGUMENT);
175+
176+
search_result = index->RangeSearch(query, radius, parameters, limited_size);
177+
REQUIRE_FALSE(search_result.has_value());
178+
REQUIRE(search_result.error().type == vsag::ErrorType::INVALID_ARGUMENT);
179+
180+
search_result = index->RangeSearch(query, radius, parameters, invalid, limited_size);
181+
REQUIRE_FALSE(search_result.has_value());
182+
REQUIRE(search_result.error().type == vsag::ErrorType::INVALID_ARGUMENT);
183+
184+
search_result = index->RangeSearch(query, radius, parameters, filter, limited_size);
185+
REQUIRE_FALSE(search_result.has_value());
186+
REQUIRE(search_result.error().type == vsag::ErrorType::INVALID_ARGUMENT);
187+
188+
search_result = index->RangeSearch(query, radius, parameters, filter_ptr, limited_size);
189+
REQUIRE_FALSE(search_result.has_value());
190+
REQUIRE(search_result.error().type == vsag::ErrorType::INVALID_ARGUMENT);
191+
}

0 commit comments

Comments
 (0)