Skip to content

Commit 89d83dc

Browse files
ShawnShawnYou authored and wxyucs committed
fix(SINDI): fix term out of bound (#1383)
Signed-off-by: zhongxiaoyao.zxy <zhongxiaoyao.zxy@antgroup.com>
1 parent ac331c0 commit 89d83dc

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

src/data_cell/sparse_term_datacell.cpp

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,11 +25,10 @@ SparseTermDataCell::Query(float* global_dists, const SparseTermComputerPtr& comp
25 25
if (computer->HasNextTerm()) {
26 26
auto next_it = it + 1;
27 27
auto next_term = computer->GetTerm(next_it);
28-
if (next_term >= term_ids_.size()) {
29-
continue;
28+
if (next_term < term_ids_.size()) {
29+
__builtin_prefetch(term_ids_[next_term].data(), 0, 3);
30+
__builtin_prefetch(term_datas_[next_term].data(), 0, 3);
30 31
}
31-
__builtin_prefetch(term_ids_[next_term].data(), 0, 3);
32-
__builtin_prefetch(term_datas_[next_term].data(), 0, 3);
33 32
}
34 33
if (term >= term_ids_.size()) {
35 34
continue;

src/data_cell/sparse_term_datacell_test.cpp

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -149,3 +149,48 @@ TEST_CASE("SparseTermDatacell Basic Test", "[ut][SparseTermDatacell]") {
149 149
delete[] query_sv.ids_;
150 150
delete[] query_sv.vals_;
151 151
}
152+
153+
TEST_CASE("SparseTermDatacell Last Term Test", "[ut][SparseTermDatacell]") {
154+
auto allocator = SafeAllocator::FactoryDefaultAllocator();
155+
156+
auto make_sv = [](const std::vector<uint32_t>& ids, const std::vector<float>& vals) {
157+
vsag::SparseVector sv;
158+
sv.len_ = static_cast<uint32_t>(ids.size());
159+
sv.ids_ = const_cast<uint32_t*>(ids.data());
160+
sv.vals_ = const_cast<float*>(vals.data());
161+
return sv;
162+
};
163+
164+
std::vector<int64_t> ids = {0, 1};
165+
166+
{
167+
std::vector<uint32_t> ids0 = {1, 2};
168+
std::vector<float> vals0 = {0.1f, 0.0f};
169+
std::vector<uint32_t> ids1 = {1};
170+
std::vector<float> vals1 = {0.1f};
171+
172+
auto sv0 = make_sv(ids0, vals0);
173+
auto sv1 = make_sv(ids1, vals1);
174+
175+
auto data_cell =
176+
std::make_shared<SparseTermDataCell>(1, DEFAULT_TERM_ID_LIMIT, allocator.get());
177+
data_cell->InsertVector(sv0, ids[0]);
178+
data_cell->InsertVector(sv1, ids[1]);
179+
180+
std::vector<uint32_t> q_ids = {1, 4};
181+
std::vector<float> q_vals = {1.0f, 1.0f};
182+
auto sv_query = make_sv(q_ids, q_vals);
183+
184+
SINDISearchParameter search_params;
185+
search_params.term_prune_ratio = 0;
186+
search_params.query_prune_ratio = 0;
187+
auto computer =
188+
std::make_shared<SparseTermComputer>(sv_query, search_params, allocator.get());
189+
190+
std::vector<float> dists(2, 0);
191+
data_cell->Query(dists.data(), computer);
192+
193+
REQUIRE(std::abs(dists[0] - (-0.1f)) < 1e-3f);
194+
REQUIRE(std::abs(dists[1] - (-0.1f)) < 1e-3f);
195+
}
196+
}

0 commit comments

Comments
 (0)