Skip to content

Commit 5f7e55b

Browse files
committed
feat: overloading for IVF nprobe
1 parent 82d4aeb commit 5f7e55b

File tree

4 files changed

+45
-15
lines changed

4 files changed

+45
-15
lines changed

build/zenann.cpython-311-darwin.so

382 KB
Binary file not shown.

include/zenann/IVFFlatIndex.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,9 @@ class IVFFlatIndex : public IndexBase {
1111
~IVFFlatIndex() override;
1212
void train() override;
1313
SearchResult search(const Vector& query, size_t k) const override;
14+
SearchResult search(const Vector& query, size_t k, size_t nprobe) const;
1415
std::vector<SearchResult> search_batch(const Dataset& queries, size_t k) const;
16+
std::vector<SearchResult> search_batch(const Dataset& queries, size_t k, size_t nprobe) const;
1517
void write_index(const std::string& filename) const;
1618
static std::shared_ptr<IVFFlatIndex> read_index(const std::string& filename);
1719
private:

python/zenann_pybind.cpp

Lines changed: 27 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -37,21 +37,33 @@ PYBIND11_MODULE(zenann, m) {
3737

3838
// Bind IVFFlatIndex
3939
py::class_<IVFFlatIndex, IndexBase, std::shared_ptr<IVFFlatIndex>>(m, "IVFFlatIndex")
40-
.def(py::init<size_t, size_t, size_t>(),
41-
py::arg("dim"), py::arg("nlist"), py::arg("nprobe") = 1)
42-
.def("build", &IVFFlatIndex::build, py::arg("data"),
43-
"Add data and train the IVF index")
44-
.def("train", &IVFFlatIndex::train,
45-
"Train the IVF index (run K-means and build inverted lists)")
46-
.def("search", &IVFFlatIndex::search, py::arg("query"), py::arg("k"),
47-
"Search top-k nearest neighbors")
48-
.def("search_batch", &IVFFlatIndex::search_batch, py::arg("queries"), py::arg("k"),
49-
"Search top-k for a batch of queries")
50-
// Persistence API
51-
.def("write_index", &IVFFlatIndex::write_index, py::arg("filename"),
52-
"Serialize the index to a file")
53-
.def_static("read_index", &IVFFlatIndex::read_index, py::arg("filename"),
54-
"Load an index from a file");
40+
.def(py::init<size_t, size_t, size_t>(),
41+
py::arg("dim"), py::arg("nlist"), py::arg("nprobe") = 1)
42+
.def("build", &IVFFlatIndex::build, py::arg("data"))
43+
.def("train", &IVFFlatIndex::train)
44+
.def("search",
45+
py::overload_cast<const Vector&, size_t>(
46+
&IVFFlatIndex::search,
47+
py::const_),
48+
py::arg("query"), py::arg("k"))
49+
.def("search",
50+
py::overload_cast<const Vector&, size_t, size_t>(
51+
&IVFFlatIndex::search,
52+
py::const_),
53+
py::arg("query"), py::arg("k"), py::arg("nprobe"))
54+
.def("search_batch",
55+
py::overload_cast<const Dataset&, size_t>(
56+
&IVFFlatIndex::search_batch,
57+
py::const_),
58+
py::arg("queries"), py::arg("k"))
59+
.def("search_batch",
60+
py::overload_cast<const Dataset&, size_t, size_t>(
61+
&IVFFlatIndex::search_batch,
62+
py::const_),
63+
py::arg("queries"), py::arg("k"), py::arg("nprobe"))
64+
65+
.def("write_index", &IVFFlatIndex::write_index, py::arg("filename"))
66+
.def_static("read_index", &IVFFlatIndex::read_index, py::arg("filename"));
5567

5668
py::class_<KDTreeIndex, IndexBase, std::shared_ptr<KDTreeIndex>>(m, "KDTreeIndex")
5769
.def(py::init<size_t>(), py::arg("dim"))

src/IVFFlatIndex.cpp

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,14 @@ SearchResult IVFFlatIndex::search(const Vector& query, size_t k) const {
9595
return result;
9696
}
9797

98+
/// @brief Run search(query, k) with `nprobe_` temporarily replaced by `nprobe`.
///
/// The member `nprobe_` is overwritten for the duration of the call, the
/// two-argument search() is invoked, and the previous value is put back
/// before this function returns.
///
/// @param query  Query vector, forwarded unchanged to search(query, k).
/// @param k      Forwarded unchanged to search(query, k).
/// @param nprobe Value stored in `nprobe_` while the delegated search runs.
/// @return The result produced by search(query, k) under the override.
///
/// @note The previous value of `nprobe_` is restored even if the delegated
///       search throws; the exception itself still propagates to the caller.
/// @warning This function writes to `nprobe_` through a const_cast and takes
///          no lock itself, so overlapping calls on the same index object can
///          see each other's override. Writing through const_cast is also
///          undefined behavior if the index object was itself declared const.
SearchResult IVFFlatIndex::search(const Vector& query, size_t k, size_t nprobe) const {
    using NprobeType = decltype(nprobe_);

    // Puts the saved value back when it leaves scope, on both the normal
    // return path and during stack unwinding.
    struct NprobeRestorer {
        NprobeType& slot;
        NprobeType saved;
        ~NprobeRestorer() { slot = saved; }
    };

    auto* self = const_cast<IVFFlatIndex*>(this);
    NprobeRestorer restore{self->nprobe_, self->nprobe_};
    self->nprobe_ = nprobe;

    // The restorer's destructor runs after the return value has been built,
    // so the delegated search sees the overridden value throughout.
    return search(query, k);
}
105+
98106
std::vector<SearchResult> IVFFlatIndex::search_batch(const Dataset& queries, size_t k) const {
99107
std::vector<SearchResult> results;
100108
results.reserve(queries.size());
@@ -104,6 +112,14 @@ std::vector<SearchResult> IVFFlatIndex::search_batch(const Dataset& queries, siz
104112
return results;
105113
}
106114

115+
std::vector<SearchResult> IVFFlatIndex::search_batch(const Dataset& queries, size_t k, size_t nprobe) const {
116+
size_t old = this->nprobe_;
117+
const_cast<IVFFlatIndex*>(this)->nprobe_ = nprobe;
118+
auto results = this->search_batch(queries, k);
119+
const_cast<IVFFlatIndex*>(this)->nprobe_ = old;
120+
return results;
121+
}
122+
107123
void IVFFlatIndex::kmeans(const Dataset& data, size_t iterations) {
108124
size_t n = data.size();
109125
std::mt19937 rng(123);

0 commit comments

Comments
 (0)