Skip to content

Commit 4a11c33

Browse files
anntiansAnnTian Shao
andauthored
Add multi-vector-support faiss patch to IndexHNSW::search_level_0 (opensearch-project#2647)
* Add multi-vector-support faiss patch to IndexHNSW::search_level_0 Signed-off-by: AnnTian Shao <anntians@amazon.com> * Add tests to JNI and KNN Signed-off-by: AnnTian Shao <anntians@amazon.com> * Update tests by adding hnsw cagra index binary and remove JNI layer method updateIndexSettings Signed-off-by: AnnTian Shao <anntians@amazon.com> * test fixes Signed-off-by: AnnTian Shao <anntians@amazon.com> --------- Signed-off-by: AnnTian Shao <anntians@amazon.com> Co-authored-by: AnnTian Shao <anntians@amazon.com>
1 parent f028757 commit 4a11c33

File tree

6 files changed

+463
-0
lines changed

6 files changed

+463
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1515
### Bug Fixes
1616
* Fixing bug to prevent NullPointerException while doing PUT mappings [#2556](https://github.com/opensearch-project/k-NN/issues/2556)
1717
* Add index operation listener to update translog source [#2629](https://github.com/opensearch-project/k-NN/pull/2629)
18+
* Add parent join support for faiss hnsw cagra [#2647](https://github.com/opensearch-project/k-NN/pull/2647)
1819
* [Remote Vector Index Build] Fix bug to support `COSINESIMIL` space type [#2627](https://github.com/opensearch-project/k-NN/pull/2627)
1920
### Infrastructure
2021
* Add github action to run ITs against remote index builder [2620](https://github.com/opensearch-project/k-NN/pull/2620)

jni/cmake/init-faiss.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ if(NOT DEFINED APPLY_LIB_PATCHES OR "${APPLY_LIB_PATCHES}" STREQUAL true)
2020
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch")
2121
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch")
2222
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0004-Custom-patch-to-support-binary-vector.patch")
23+
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0005-Custom-patch-to-support-multi-vector-IndexHNSW-search_level_0.patch")
2324

2425
# Get patch id of the last commit
2526
execute_process(COMMAND sh -c "git --no-pager show HEAD | git patch-id --stable" OUTPUT_VARIABLE PATCH_ID_OUTPUT_FROM_COMMIT WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss)
Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
From 9ef5e349ca5893da07898d7f1d22b0a81f17fddc Mon Sep 17 00:00:00 2001
2+
From: AnnTian Shao <anntians@amazon.com>
3+
Date: Thu, 3 Apr 2025 21:21:11 +0000
4+
Subject: [PATCH] Add multi-vector-support faiss patch to
5+
IndexHNSW::search_level_0
6+
7+
Signed-off-by: AnnTian Shao <anntians@amazon.com>
8+
---
9+
faiss/IndexHNSW.cpp | 123 +++++++++++++++++++++++++-----------
10+
faiss/index_factory.cpp | 7 ++-
11+
tests/test_id_grouper.cpp | 128 ++++++++++++++++++++++++++++++++++++++
12+
3 files changed, 222 insertions(+), 36 deletions(-)
13+
14+
diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp
15+
index eee3e99c6..7c5dfe020 100644
16+
--- a/faiss/IndexHNSW.cpp
17+
+++ b/faiss/IndexHNSW.cpp
18+
@@ -286,6 +286,61 @@ void hnsw_search(
19+
hnsw_stats.combine({n1, n2, ndis, nhops});
20+
}
21+
22+
+template <class BlockResultHandler>
23+
+void hnsw_search_level_0(
24+
+ const IndexHNSW* index,
25+
+ idx_t n,
26+
+ const float* x,
27+
+ idx_t k,
28+
+ const storage_idx_t* nearest,
29+
+ const float* nearest_d,
30+
+ float* distances,
31+
+ idx_t* labels,
32+
+ int nprobe,
33+
+ int search_type,
34+
+ const SearchParameters* params_in,
35+
+ BlockResultHandler& bres) {
36+
+
37+
+ const HNSW& hnsw = index->hnsw;
38+
+ const SearchParametersHNSW* params = nullptr;
39+
+
40+
+ if (params_in) {
41+
+ params = dynamic_cast<const SearchParametersHNSW*>(params_in);
42+
+ FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
43+
+ }
44+
+
45+
+#pragma omp parallel
46+
+ {
47+
+ std::unique_ptr<DistanceComputer> qdis(
48+
+ storage_distance_computer(index->storage));
49+
+ HNSWStats search_stats;
50+
+ VisitedTable vt(index->ntotal);
51+
+ typename BlockResultHandler::SingleResultHandler res(bres);
52+
+
53+
+#pragma omp for
54+
+ for (idx_t i = 0; i < n; i++) {
55+
+ res.begin(i);
56+
+ qdis->set_query(x + i * index->d);
57+
+
58+
+ hnsw.search_level_0(
59+
+ *qdis.get(),
60+
+ res,
61+
+ nprobe,
62+
+ nearest + i * nprobe,
63+
+ nearest_d + i * nprobe,
64+
+ search_type,
65+
+ search_stats,
66+
+ vt,
67+
+ params);
68+
+ res.end();
69+
+ vt.advance();
70+
+ }
71+
+#pragma omp critical
72+
+ { hnsw_stats.combine(search_stats); }
73+
+ }
74+
+
75+
+}
76+
+
77+
} // anonymous namespace
78+
79+
void IndexHNSW::search(
80+
@@ -419,46 +474,44 @@ void IndexHNSW::search_level_0(
81+
FAISS_THROW_IF_NOT(k > 0);
82+
FAISS_THROW_IF_NOT(nprobe > 0);
83+
84+
- const SearchParametersHNSW* params = nullptr;
85+
-
86+
- if (params_in) {
87+
- params = dynamic_cast<const SearchParametersHNSW*>(params_in);
88+
- FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
89+
- }
90+
-
91+
storage_idx_t ntotal = hnsw.levels.size();
92+
93+
- using RH = HeapBlockResultHandler<HNSW::C>;
94+
- RH bres(n, distances, labels, k);
95+
96+
-#pragma omp parallel
97+
- {
98+
- std::unique_ptr<DistanceComputer> qdis(
99+
- storage_distance_computer(storage));
100+
- HNSWStats search_stats;
101+
- VisitedTable vt(ntotal);
102+
- RH::SingleResultHandler res(bres);
103+
+ if (params_in && params_in->grp) {
104+
+ using RH = GroupedHeapBlockResultHandler<HNSW::C>;
105+
+ RH bres(n, distances, labels, k, params_in->grp);
106+
107+
-#pragma omp for
108+
- for (idx_t i = 0; i < n; i++) {
109+
- res.begin(i);
110+
- qdis->set_query(x + i * d);
111+
112+
- hnsw.search_level_0(
113+
- *qdis.get(),
114+
- res,
115+
- nprobe,
116+
- nearest + i * nprobe,
117+
- nearest_d + i * nprobe,
118+
- search_type,
119+
- search_stats,
120+
- vt,
121+
- params);
122+
- res.end();
123+
- vt.advance();
124+
- }
125+
-#pragma omp critical
126+
- { hnsw_stats.combine(search_stats); }
127+
+ hnsw_search_level_0(
128+
+ this,
129+
+ n,
130+
+ x,
131+
+ k,
132+
+ nearest,
133+
+ nearest_d,
134+
+ distances,
135+
+ labels,
136+
+ nprobe, // n_probes
137+
+ search_type, // search_type
138+
+ params_in,
139+
+ bres);
140+
+ } else {
141+
+ using RH = HeapBlockResultHandler<HNSW::C>;
142+
+ RH bres(n, distances, labels, k);
143+
+
144+
+ hnsw_search_level_0(
145+
+ this,
146+
+ n,
147+
+ x,
148+
+ k,
149+
+ nearest,
150+
+ nearest_d,
151+
+ distances,
152+
+ labels,
153+
+ nprobe, // n_probes
154+
+ search_type, // search_type
155+
+ params_in,
156+
+ bres);
157+
}
158+
if (is_similarity_metric(this->metric_type)) {
159+
// we need to revert the negated distances
160+
diff --git a/faiss/index_factory.cpp b/faiss/index_factory.cpp
161+
index 8ff4bfec7..24e65b632 100644
162+
--- a/faiss/index_factory.cpp
163+
+++ b/faiss/index_factory.cpp
164+
@@ -453,6 +453,11 @@ IndexHNSW* parse_IndexHNSW(
165+
return re_match(code_string, pattern, sm);
166+
};
167+
168+
+ if (match("Cagra")) {
169+
+ IndexHNSWCagra* cagra = new IndexHNSWCagra(d, hnsw_M, mt);
170+
+ return cagra;
171+
+ }
172+
+
173+
if (match("Flat|")) {
174+
return new IndexHNSWFlat(d, hnsw_M, mt);
175+
}
176+
@@ -781,7 +786,7 @@ std::unique_ptr<Index> index_factory_sub(
177+
178+
// HNSW variants (it was unclear in the old version that the separator was a
179+
// "," so we support both "_" and ",")
180+
- if (re_match(description, "HNSW([0-9]*)([,_].*)?", sm)) {
181+
+ if (re_match(description, "HNSW([0-9]*)([,_].*)?(Cagra)?", sm)) {
182+
int hnsw_M = mres_to_int(sm[1], 32);
183+
// We also accept empty code string (synonym of Flat)
184+
std::string code_string =
185+
diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp
186+
index bd8ab5f9d..ebe16a364 100644
187+
--- a/tests/test_id_grouper.cpp
188+
+++ b/tests/test_id_grouper.cpp
189+
@@ -172,6 +172,65 @@ TEST(IdGrouper, bitmap_with_hnsw) {
190+
delete[] xb;
191+
}
192+
193+
+TEST(IdGrouper, bitmap_with_hnsw_cagra) {
194+
+ int d = 1; // dimension
195+
+ int nb = 10; // database size
196+
+
197+
+ std::mt19937 rng;
198+
+ std::uniform_real_distribution<> distrib;
199+
+
200+
+ float* xb = new float[d * nb];
201+
+
202+
+ for (int i = 0; i < nb; i++) {
203+
+ for (int j = 0; j < d; j++)
204+
+ xb[d * i + j] = distrib(rng);
205+
+ xb[d * i] += i / 1000.;
206+
+ }
207+
+
208+
+ uint64_t bitmap[1] = {};
209+
+ faiss::IDGrouperBitmap id_grouper(1, bitmap);
210+
+ for (int i = 0; i < nb; i++) {
211+
+ if (i % 2 == 1) {
212+
+ id_grouper.set_group(i);
213+
+ }
214+
+ }
215+
+
216+
+ int k = 10;
217+
+ int m = 8;
218+
+ faiss::Index* index =
219+
+ new faiss::IndexHNSWCagra(d, m, faiss::MetricType::METRIC_L2);
220+
+ index->add(nb, xb); // add vectors to the index
221+
+ dynamic_cast<faiss::IndexHNSWCagra*>(index)->base_level_only=true;
222+
+
223+
+ // search
224+
+ idx_t* I = new idx_t[k];
225+
+ float* D = new float[k];
226+
+
227+
+ auto pSearchParameters = new faiss::SearchParametersHNSW();
228+
+ pSearchParameters->grp = &id_grouper;
229+
+
230+
+ index->search(1, xb, k, D, I, pSearchParameters);
231+
+
232+
+ std::unordered_set<int> group_ids;
233+
+ ASSERT_EQ(0, I[0]);
234+
+ ASSERT_EQ(0, D[0]);
235+
+ group_ids.insert(id_grouper.get_group(I[0]));
236+
+ for (int j = 1; j < 5; j++) {
237+
+ ASSERT_NE(-1, I[j]);
238+
+ ASSERT_NE(std::numeric_limits<float>::max(), D[j]);
239+
+ group_ids.insert(id_grouper.get_group(I[j]));
240+
+ }
241+
+ for (int j = 5; j < k; j++) {
242+
+ ASSERT_EQ(-1, I[j]);
243+
+ ASSERT_EQ(std::numeric_limits<float>::max(), D[j]);
244+
+ }
245+
+ ASSERT_EQ(5, group_ids.size());
246+
+
247+
+ delete[] I;
248+
+ delete[] D;
249+
+ delete[] xb;
250+
+}
251+
+
252+
TEST(IdGrouper, bitmap_with_binary_hnsw) {
253+
int d = 16; // dimension
254+
int nb = 10; // database size
255+
@@ -291,6 +350,75 @@ TEST(IdGrouper, bitmap_with_hnsw_idmap) {
256+
delete[] xb;
257+
}
258+
259+
+TEST(IdGrouper, bitmap_with_hnsw_cagra_idmap) {
260+
+ int d = 1; // dimension
261+
+ int nb = 10; // database size
262+
+
263+
+ std::mt19937 rng;
264+
+ std::uniform_real_distribution<> distrib;
265+
+
266+
+ float* xb = new float[d * nb];
267+
+ idx_t* xids = new idx_t[d * nb];
268+
+
269+
+ for (int i = 0; i < nb; i++) {
270+
+ for (int j = 0; j < d; j++)
271+
+ xb[d * i + j] = distrib(rng);
272+
+ xb[d * i] += i / 1000.;
273+
+ }
274+
+
275+
+ uint64_t bitmap[1] = {};
276+
+ faiss::IDGrouperBitmap id_grouper(1, bitmap);
277+
+ int num_grp = 0;
278+
+ int grp_size = 2;
279+
+ int id_in_grp = 0;
280+
+ for (int i = 0; i < nb; i++) {
281+
+ xids[i] = i + num_grp;
282+
+ id_in_grp++;
283+
+ if (id_in_grp == grp_size) {
284+
+ id_grouper.set_group(i + num_grp + 1);
285+
+ num_grp++;
286+
+ id_in_grp = 0;
287+
+ }
288+
+ }
289+
+
290+
+ int k = 10;
291+
+ int m = 8;
292+
+ faiss::Index* index =
293+
+ new faiss::IndexHNSWCagra(d, m, faiss::MetricType::METRIC_L2);
294+
+ faiss::IndexIDMap id_map =
295+
+ faiss::IndexIDMap(index); // add vectors to the index
296+
+ id_map.add_with_ids(nb, xb, xids);
297+
+ dynamic_cast<faiss::IndexHNSWCagra*>(id_map.index)->base_level_only=true;
298+
+
299+
+ // search
300+
+ idx_t* I = new idx_t[k];
301+
+ float* D = new float[k];
302+
+
303+
+ auto pSearchParameters = new faiss::SearchParametersHNSW();
304+
+ pSearchParameters->grp = &id_grouper;
305+
+
306+
+ id_map.search(1, xb, k, D, I, pSearchParameters);
307+
+
308+
+ std::unordered_set<int> group_ids;
309+
+ ASSERT_EQ(0, I[0]);
310+
+ ASSERT_EQ(0, D[0]);
311+
+ group_ids.insert(id_grouper.get_group(I[0]));
312+
+ for (int j = 1; j < 5; j++) {
313+
+ ASSERT_NE(-1, I[j]);
314+
+ ASSERT_NE(std::numeric_limits<float>::max(), D[j]);
315+
+ group_ids.insert(id_grouper.get_group(I[j]));
316+
+ }
317+
+ for (int j = 5; j < k; j++) {
318+
+ ASSERT_EQ(-1, I[j]);
319+
+ ASSERT_EQ(std::numeric_limits<float>::max(), D[j]);
320+
+ }
321+
+ ASSERT_EQ(5, group_ids.size());
322+
+
323+
+ delete[] I;
324+
+ delete[] D;
325+
+ delete[] xb;
326+
+}
327+
+
328+
TEST(IdGrouper, bitmap_with_binary_hnsw_idmap) {
329+
int d = 16; // dimension
330+
int nb = 10; // database size
331+
--
332+
2.47.1
333+

0 commit comments

Comments
 (0)