|
| 1 | +From 9ef5e349ca5893da07898d7f1d22b0a81f17fddc Mon Sep 17 00:00:00 2001 |
| 2 | +From: AnnTian Shao <anntians@amazon.com> |
| 3 | +Date: Thu, 3 Apr 2025 21:21:11 +0000 |
| 4 | +Subject: [PATCH] Add multi-vector-support faiss patch to |
| 5 | + IndexHNSW::search_level_0 |
| 6 | + |
| 7 | +Signed-off-by: AnnTian Shao <anntians@amazon.com> |
| 8 | +--- |
| 9 | + faiss/IndexHNSW.cpp | 123 +++++++++++++++++++++++++----------- |
| 10 | + faiss/index_factory.cpp | 7 ++- |
| 11 | + tests/test_id_grouper.cpp | 128 ++++++++++++++++++++++++++++++++++++++ |
| 12 | + 3 files changed, 222 insertions(+), 36 deletions(-) |
| 13 | + |
| 14 | +diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp |
| 15 | +index eee3e99c6..7c5dfe020 100644 |
| 16 | +--- a/faiss/IndexHNSW.cpp |
| 17 | ++++ b/faiss/IndexHNSW.cpp |
| 18 | +@@ -286,6 +286,61 @@ void hnsw_search( |
| 19 | + hnsw_stats.combine({n1, n2, ndis, nhops}); |
| 20 | + } |
| 21 | + |
| 22 | ++template <class BlockResultHandler> |
| 23 | ++void hnsw_search_level_0( |
| 24 | ++ const IndexHNSW* index, |
| 25 | ++ idx_t n, |
| 26 | ++ const float* x, |
| 27 | ++ idx_t k, |
| 28 | ++ const storage_idx_t* nearest, |
| 29 | ++ const float* nearest_d, |
| 30 | ++ float* distances, |
| 31 | ++ idx_t* labels, |
| 32 | ++ int nprobe, |
| 33 | ++ int search_type, |
| 34 | ++ const SearchParameters* params_in, |
| 35 | ++ BlockResultHandler& bres) { |
| 36 | ++ |
| 37 | ++ const HNSW& hnsw = index->hnsw; |
| 38 | ++ const SearchParametersHNSW* params = nullptr; |
| 39 | ++ |
| 40 | ++ if (params_in) { |
| 41 | ++ params = dynamic_cast<const SearchParametersHNSW*>(params_in); |
| 42 | ++ FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); |
| 43 | ++ } |
| 44 | ++ |
| 45 | ++#pragma omp parallel |
| 46 | ++ { |
| 47 | ++ std::unique_ptr<DistanceComputer> qdis( |
| 48 | ++ storage_distance_computer(index->storage)); |
| 49 | ++ HNSWStats search_stats; |
| 50 | ++ VisitedTable vt(index->ntotal); |
| 51 | ++ typename BlockResultHandler::SingleResultHandler res(bres); |
| 52 | ++ |
| 53 | ++#pragma omp for |
| 54 | ++ for (idx_t i = 0; i < n; i++) { |
| 55 | ++ res.begin(i); |
| 56 | ++ qdis->set_query(x + i * index->d); |
| 57 | ++ |
| 58 | ++ hnsw.search_level_0( |
| 59 | ++ *qdis.get(), |
| 60 | ++ res, |
| 61 | ++ nprobe, |
| 62 | ++ nearest + i * nprobe, |
| 63 | ++ nearest_d + i * nprobe, |
| 64 | ++ search_type, |
| 65 | ++ search_stats, |
| 66 | ++ vt, |
| 67 | ++ params); |
| 68 | ++ res.end(); |
| 69 | ++ vt.advance(); |
| 70 | ++ } |
| 71 | ++#pragma omp critical |
| 72 | ++ { hnsw_stats.combine(search_stats); } |
| 73 | ++ } |
| 74 | ++ |
| 75 | ++} |
| 76 | ++ |
| 77 | + } // anonymous namespace |
| 78 | + |
| 79 | + void IndexHNSW::search( |
| 80 | +@@ -419,46 +474,44 @@ void IndexHNSW::search_level_0( |
| 81 | + FAISS_THROW_IF_NOT(k > 0); |
| 82 | + FAISS_THROW_IF_NOT(nprobe > 0); |
| 83 | + |
| 84 | +- const SearchParametersHNSW* params = nullptr; |
| 85 | +- |
| 86 | +- if (params_in) { |
| 87 | +- params = dynamic_cast<const SearchParametersHNSW*>(params_in); |
| 88 | +- FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); |
| 89 | +- } |
| 90 | +- |
| 91 | + storage_idx_t ntotal = hnsw.levels.size(); |
| 92 | + |
| 93 | +- using RH = HeapBlockResultHandler<HNSW::C>; |
| 94 | +- RH bres(n, distances, labels, k); |
| 95 | + |
| 96 | +-#pragma omp parallel |
| 97 | +- { |
| 98 | +- std::unique_ptr<DistanceComputer> qdis( |
| 99 | +- storage_distance_computer(storage)); |
| 100 | +- HNSWStats search_stats; |
| 101 | +- VisitedTable vt(ntotal); |
| 102 | +- RH::SingleResultHandler res(bres); |
| 103 | ++ if (params_in && params_in->grp) { |
| 104 | ++ using RH = GroupedHeapBlockResultHandler<HNSW::C>; |
| 105 | ++ RH bres(n, distances, labels, k, params_in->grp); |
| 106 | + |
| 107 | +-#pragma omp for |
| 108 | +- for (idx_t i = 0; i < n; i++) { |
| 109 | +- res.begin(i); |
| 110 | +- qdis->set_query(x + i * d); |
| 111 | + |
| 112 | +- hnsw.search_level_0( |
| 113 | +- *qdis.get(), |
| 114 | +- res, |
| 115 | +- nprobe, |
| 116 | +- nearest + i * nprobe, |
| 117 | +- nearest_d + i * nprobe, |
| 118 | +- search_type, |
| 119 | +- search_stats, |
| 120 | +- vt, |
| 121 | +- params); |
| 122 | +- res.end(); |
| 123 | +- vt.advance(); |
| 124 | +- } |
| 125 | +-#pragma omp critical |
| 126 | +- { hnsw_stats.combine(search_stats); } |
| 127 | ++ hnsw_search_level_0( |
| 128 | ++ this, |
| 129 | ++ n, |
| 130 | ++ x, |
| 131 | ++ k, |
| 132 | ++ nearest, |
| 133 | ++ nearest_d, |
| 134 | ++ distances, |
| 135 | ++ labels, |
| 136 | ++ nprobe, // n_probes |
| 137 | ++ search_type, // search_type |
| 138 | ++ params_in, |
| 139 | ++ bres); |
| 140 | ++ } else { |
| 141 | ++ using RH = HeapBlockResultHandler<HNSW::C>; |
| 142 | ++ RH bres(n, distances, labels, k); |
| 143 | ++ |
| 144 | ++ hnsw_search_level_0( |
| 145 | ++ this, |
| 146 | ++ n, |
| 147 | ++ x, |
| 148 | ++ k, |
| 149 | ++ nearest, |
| 150 | ++ nearest_d, |
| 151 | ++ distances, |
| 152 | ++ labels, |
| 153 | ++ nprobe, // n_probes |
| 154 | ++ search_type, // search_type |
| 155 | ++ params_in, |
| 156 | ++ bres); |
| 157 | + } |
| 158 | + if (is_similarity_metric(this->metric_type)) { |
| 159 | + // we need to revert the negated distances |
| 160 | +diff --git a/faiss/index_factory.cpp b/faiss/index_factory.cpp |
| 161 | +index 8ff4bfec7..24e65b632 100644 |
| 162 | +--- a/faiss/index_factory.cpp |
| 163 | ++++ b/faiss/index_factory.cpp |
| 164 | +@@ -453,6 +453,11 @@ IndexHNSW* parse_IndexHNSW( |
| 165 | + return re_match(code_string, pattern, sm); |
| 166 | + }; |
| 167 | + |
| 168 | ++ if (match("Cagra")) { |
| 169 | ++ IndexHNSWCagra* cagra = new IndexHNSWCagra(d, hnsw_M, mt); |
| 170 | ++ return cagra; |
| 171 | ++ } |
| 172 | ++ |
| 173 | + if (match("Flat|")) { |
| 174 | + return new IndexHNSWFlat(d, hnsw_M, mt); |
| 175 | + } |
| 176 | +@@ -781,7 +786,7 @@ std::unique_ptr<Index> index_factory_sub( |
| 177 | + |
| 178 | + // HNSW variants (it was unclear in the old version that the separator was a |
| 179 | + // "," so we support both "_" and ",") |
| 180 | +- if (re_match(description, "HNSW([0-9]*)([,_].*)?", sm)) { |
| 181 | ++ if (re_match(description, "HNSW([0-9]*)([,_].*)?(Cagra)?", sm)) { |
| 182 | + int hnsw_M = mres_to_int(sm[1], 32); |
| 183 | + // We also accept empty code string (synonym of Flat) |
| 184 | + std::string code_string = |
| 185 | +diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp |
| 186 | +index bd8ab5f9d..ebe16a364 100644 |
| 187 | +--- a/tests/test_id_grouper.cpp |
| 188 | ++++ b/tests/test_id_grouper.cpp |
| 189 | +@@ -172,6 +172,65 @@ TEST(IdGrouper, bitmap_with_hnsw) { |
| 190 | + delete[] xb; |
| 191 | + } |
| 192 | + |
| 193 | ++TEST(IdGrouper, bitmap_with_hnsw_cagra) { |
| 194 | ++ int d = 1; // dimension |
| 195 | ++ int nb = 10; // database size |
| 196 | ++ |
| 197 | ++ std::mt19937 rng; |
| 198 | ++ std::uniform_real_distribution<> distrib; |
| 199 | ++ |
| 200 | ++ float* xb = new float[d * nb]; |
| 201 | ++ |
| 202 | ++ for (int i = 0; i < nb; i++) { |
| 203 | ++ for (int j = 0; j < d; j++) |
| 204 | ++ xb[d * i + j] = distrib(rng); |
| 205 | ++ xb[d * i] += i / 1000.; |
| 206 | ++ } |
| 207 | ++ |
| 208 | ++ uint64_t bitmap[1] = {}; |
| 209 | ++ faiss::IDGrouperBitmap id_grouper(1, bitmap); |
| 210 | ++ for (int i = 0; i < nb; i++) { |
| 211 | ++ if (i % 2 == 1) { |
| 212 | ++ id_grouper.set_group(i); |
| 213 | ++ } |
| 214 | ++ } |
| 215 | ++ |
| 216 | ++ int k = 10; |
| 217 | ++ int m = 8; |
| 218 | ++ faiss::Index* index = |
| 219 | ++ new faiss::IndexHNSWCagra(d, m, faiss::MetricType::METRIC_L2); |
| 220 | ++ index->add(nb, xb); // add vectors to the index |
| 221 | ++ dynamic_cast<faiss::IndexHNSWCagra*>(index)->base_level_only=true; |
| 222 | ++ |
| 223 | ++ // search |
| 224 | ++ idx_t* I = new idx_t[k]; |
| 225 | ++ float* D = new float[k]; |
| 226 | ++ |
| 227 | ++ auto pSearchParameters = new faiss::SearchParametersHNSW(); |
| 228 | ++ pSearchParameters->grp = &id_grouper; |
| 229 | ++ |
| 230 | ++ index->search(1, xb, k, D, I, pSearchParameters); |
| 231 | ++ |
| 232 | ++ std::unordered_set<int> group_ids; |
| 233 | ++ ASSERT_EQ(0, I[0]); |
| 234 | ++ ASSERT_EQ(0, D[0]); |
| 235 | ++ group_ids.insert(id_grouper.get_group(I[0])); |
| 236 | ++ for (int j = 1; j < 5; j++) { |
| 237 | ++ ASSERT_NE(-1, I[j]); |
| 238 | ++ ASSERT_NE(std::numeric_limits<float>::max(), D[j]); |
| 239 | ++ group_ids.insert(id_grouper.get_group(I[j])); |
| 240 | ++ } |
| 241 | ++ for (int j = 5; j < k; j++) { |
| 242 | ++ ASSERT_EQ(-1, I[j]); |
| 243 | ++ ASSERT_EQ(std::numeric_limits<float>::max(), D[j]); |
| 244 | ++ } |
| 245 | ++ ASSERT_EQ(5, group_ids.size()); |
| 246 | ++ |
| 247 | ++ delete[] I; |
| 248 | ++ delete[] D; |
| 249 | ++ delete[] xb; |
| 250 | ++} |
| 251 | ++ |
| 252 | + TEST(IdGrouper, bitmap_with_binary_hnsw) { |
| 253 | + int d = 16; // dimension |
| 254 | + int nb = 10; // database size |
| 255 | +@@ -291,6 +350,75 @@ TEST(IdGrouper, bitmap_with_hnsw_idmap) { |
| 256 | + delete[] xb; |
| 257 | + } |
| 258 | + |
| 259 | ++TEST(IdGrouper, bitmap_with_hnsw_cagra_idmap) { |
| 260 | ++ int d = 1; // dimension |
| 261 | ++ int nb = 10; // database size |
| 262 | ++ |
| 263 | ++ std::mt19937 rng; |
| 264 | ++ std::uniform_real_distribution<> distrib; |
| 265 | ++ |
| 266 | ++ float* xb = new float[d * nb]; |
| 267 | ++ idx_t* xids = new idx_t[d * nb]; |
| 268 | ++ |
| 269 | ++ for (int i = 0; i < nb; i++) { |
| 270 | ++ for (int j = 0; j < d; j++) |
| 271 | ++ xb[d * i + j] = distrib(rng); |
| 272 | ++ xb[d * i] += i / 1000.; |
| 273 | ++ } |
| 274 | ++ |
| 275 | ++ uint64_t bitmap[1] = {}; |
| 276 | ++ faiss::IDGrouperBitmap id_grouper(1, bitmap); |
| 277 | ++ int num_grp = 0; |
| 278 | ++ int grp_size = 2; |
| 279 | ++ int id_in_grp = 0; |
| 280 | ++ for (int i = 0; i < nb; i++) { |
| 281 | ++ xids[i] = i + num_grp; |
| 282 | ++ id_in_grp++; |
| 283 | ++ if (id_in_grp == grp_size) { |
| 284 | ++ id_grouper.set_group(i + num_grp + 1); |
| 285 | ++ num_grp++; |
| 286 | ++ id_in_grp = 0; |
| 287 | ++ } |
| 288 | ++ } |
| 289 | ++ |
| 290 | ++ int k = 10; |
| 291 | ++ int m = 8; |
| 292 | ++ faiss::Index* index = |
| 293 | ++ new faiss::IndexHNSWCagra(d, m, faiss::MetricType::METRIC_L2); |
| 294 | ++ faiss::IndexIDMap id_map = |
| 295 | ++ faiss::IndexIDMap(index); // add vectors to the index |
| 296 | ++ id_map.add_with_ids(nb, xb, xids); |
| 297 | ++ dynamic_cast<faiss::IndexHNSWCagra*>(id_map.index)->base_level_only=true; |
| 298 | ++ |
| 299 | ++ // search |
| 300 | ++ idx_t* I = new idx_t[k]; |
| 301 | ++ float* D = new float[k]; |
| 302 | ++ |
| 303 | ++ auto pSearchParameters = new faiss::SearchParametersHNSW(); |
| 304 | ++ pSearchParameters->grp = &id_grouper; |
| 305 | ++ |
| 306 | ++ id_map.search(1, xb, k, D, I, pSearchParameters); |
| 307 | ++ |
| 308 | ++ std::unordered_set<int> group_ids; |
| 309 | ++ ASSERT_EQ(0, I[0]); |
| 310 | ++ ASSERT_EQ(0, D[0]); |
| 311 | ++ group_ids.insert(id_grouper.get_group(I[0])); |
| 312 | ++ for (int j = 1; j < 5; j++) { |
| 313 | ++ ASSERT_NE(-1, I[j]); |
| 314 | ++ ASSERT_NE(std::numeric_limits<float>::max(), D[j]); |
| 315 | ++ group_ids.insert(id_grouper.get_group(I[j])); |
| 316 | ++ } |
| 317 | ++ for (int j = 5; j < k; j++) { |
| 318 | ++ ASSERT_EQ(-1, I[j]); |
| 319 | ++ ASSERT_EQ(std::numeric_limits<float>::max(), D[j]); |
| 320 | ++ } |
| 321 | ++ ASSERT_EQ(5, group_ids.size()); |
| 322 | ++ |
| 323 | ++ delete[] I; |
| 324 | ++ delete[] D; |
| 325 | ++ delete[] xb; |
| 326 | ++} |
| 327 | ++ |
| 328 | + TEST(IdGrouper, bitmap_with_binary_hnsw_idmap) { |
| 329 | + int d = 16; // dimension |
| 330 | + int nb = 10; // database size |
| 331 | +-- |
| 332 | +2.47.1 |
| 333 | + |
0 commit comments