@@ -11,20 +11,25 @@ namespace {
1111
1212using idx_t = hnswlib::labeltype;
1313
14- bool pickIdsDivisibleByThree (unsigned int label_id) {
15- return label_id % 3 == 0 ;
16- }
17-
18- bool pickIdsDivisibleBySeven (unsigned int label_id) {
19- return label_id % 7 == 0 ;
20- }
14+ class PickDivisibleIds : public hnswlib ::BaseFilterFunctor {
15+ unsigned int divisor = 1 ;
16+ public:
17+ PickDivisibleIds (unsigned int divisor): divisor(divisor) {
18+ assert (divisor != 0 );
19+ }
20+ bool operator ()(idx_t label_id) {
21+ return label_id % divisor == 0 ;
22+ }
23+ };
2124
22- bool pickNothing (unsigned int label_id) {
23- return false ;
24- }
25+ class PickNothing : public hnswlib ::BaseFilterFunctor {
26+ public:
27+ bool operator ()(idx_t label_id) {
28+ return false ;
29+ }
30+ };
2531
26- template <typename filter_func_t >
27- void test_some_filtering (filter_func_t & filter_func, size_t div_num, size_t label_id_start) {
32+ void test_some_filtering (hnswlib::BaseFilterFunctor& filter_func, size_t div_num, size_t label_id_start) {
2833 int d = 4 ;
2934 idx_t n = 100 ;
3035 idx_t nq = 10 ;
@@ -45,8 +50,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
4550 }
4651
4752 hnswlib::L2Space space (d);
48- hnswlib::AlgorithmInterface<float , filter_func_t >* alg_brute = new hnswlib::BruteforceSearch<float , filter_func_t >(&space, 2 * n);
49- hnswlib::AlgorithmInterface<float , filter_func_t >* alg_hnsw = new hnswlib::HierarchicalNSW<float , filter_func_t >(&space, 2 * n);
53+ hnswlib::AlgorithmInterface<float >* alg_brute = new hnswlib::BruteforceSearch<float >(&space, 2 * n);
54+ hnswlib::AlgorithmInterface<float >* alg_hnsw = new hnswlib::HierarchicalNSW<float >(&space, 2 * n);
5055
5156 for (size_t i = 0 ; i < n; ++i) {
5257 // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
@@ -57,8 +62,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
5762 // test searchKnnCloserFirst of BruteforceSearch with filtering
5863 for (size_t j = 0 ; j < nq; ++j) {
5964 const void * p = query.data () + j * d;
60- auto gd = alg_brute->searchKnn (p, k, filter_func);
61- auto res = alg_brute->searchKnnCloserFirst (p, k, filter_func);
65+ auto gd = alg_brute->searchKnn (p, k, & filter_func);
66+ auto res = alg_brute->searchKnnCloserFirst (p, k, & filter_func);
6267 assert (gd.size () == res.size ());
6368 size_t t = gd.size ();
6469 while (!gd.empty ()) {
@@ -71,8 +76,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
7176 // test searchKnnCloserFirst of hnsw with filtering
7277 for (size_t j = 0 ; j < nq; ++j) {
7378 const void * p = query.data () + j * d;
74- auto gd = alg_hnsw->searchKnn (p, k, filter_func);
75- auto res = alg_hnsw->searchKnnCloserFirst (p, k, filter_func);
79+ auto gd = alg_hnsw->searchKnn (p, k, & filter_func);
80+ auto res = alg_hnsw->searchKnnCloserFirst (p, k, & filter_func);
7681 assert (gd.size () == res.size ());
7782 size_t t = gd.size ();
7883 while (!gd.empty ()) {
@@ -86,8 +91,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
8691 delete alg_hnsw;
8792}
8893
89- template <typename filter_func_t >
90- void test_none_filtering (filter_func_t & filter_func, size_t label_id_start) {
94+ void test_none_filtering (hnswlib::BaseFilterFunctor& filter_func, size_t label_id_start) {
9195 int d = 4 ;
9296 idx_t n = 100 ;
9397 idx_t nq = 10 ;
@@ -108,8 +112,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
108112 }
109113
110114 hnswlib::L2Space space (d);
111- hnswlib::AlgorithmInterface<float , filter_func_t >* alg_brute = new hnswlib::BruteforceSearch<float , filter_func_t >(&space, 2 * n);
112- hnswlib::AlgorithmInterface<float , filter_func_t >* alg_hnsw = new hnswlib::HierarchicalNSW<float , filter_func_t >(&space, 2 * n);
115+ hnswlib::AlgorithmInterface<float >* alg_brute = new hnswlib::BruteforceSearch<float >(&space, 2 * n);
116+ hnswlib::AlgorithmInterface<float >* alg_hnsw = new hnswlib::HierarchicalNSW<float >(&space, 2 * n);
113117
114118 for (size_t i = 0 ; i < n; ++i) {
115119 // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
@@ -120,17 +124,17 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
120124 // test searchKnnCloserFirst of BruteforceSearch with filtering
121125 for (size_t j = 0 ; j < nq; ++j) {
122126 const void * p = query.data () + j * d;
123- auto gd = alg_brute->searchKnn (p, k, filter_func);
124- auto res = alg_brute->searchKnnCloserFirst (p, k, filter_func);
127+ auto gd = alg_brute->searchKnn (p, k, & filter_func);
128+ auto res = alg_brute->searchKnnCloserFirst (p, k, & filter_func);
125129 assert (gd.size () == res.size ());
126130 assert (0 == gd.size ());
127131 }
128132
129133 // test searchKnnCloserFirst of hnsw with filtering
130134 for (size_t j = 0 ; j < nq; ++j) {
131135 const void * p = query.data () + j * d;
132- auto gd = alg_hnsw->searchKnn (p, k, filter_func);
133- auto res = alg_hnsw->searchKnnCloserFirst (p, k, filter_func);
136+ auto gd = alg_hnsw->searchKnn (p, k, & filter_func);
137+ auto res = alg_hnsw->searchKnnCloserFirst (p, k, & filter_func);
134138 assert (gd.size () == res.size ());
135139 assert (0 == gd.size ());
136140 }
@@ -141,13 +145,13 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
141145
142146} // namespace
143147
144- class CustomFilterFunctor : public hnswlib ::FilterFunctor {
145- std::unordered_set<unsigned int > allowed_values;
148+ class CustomFilterFunctor : public hnswlib ::BaseFilterFunctor {
149+ std::unordered_set<idx_t > allowed_values;
146150
147151 public:
148- explicit CustomFilterFunctor (const std::unordered_set<unsigned int >& values) : allowed_values(values) {}
152+ explicit CustomFilterFunctor (const std::unordered_set<idx_t >& values) : allowed_values(values) {}
149153
150- bool operator ()(unsigned int id) {
154+ bool operator ()(idx_t id) {
151155 return allowed_values.count (id) != 0 ;
152156 }
153157};
@@ -156,10 +160,13 @@ int main() {
156160 std::cout << " Testing ..." << std::endl;
157161
158162 // some of the elements are filtered
163+ PickDivisibleIds pickIdsDivisibleByThree (3 );
159164 test_some_filtering (pickIdsDivisibleByThree, 3 , 17 );
165+ PickDivisibleIds pickIdsDivisibleBySeven (7 );
160166 test_some_filtering (pickIdsDivisibleBySeven, 7 , 17 );
161167
162168 // all of the elements are filtered
169+ PickNothing pickNothing;
163170 test_none_filtering (pickNothing, 17 );
164171
165172 // functor style which can capture context
0 commit comments