Skip to content

Commit 263fa28

Browse files
committed
feat: pybind wrapper for KDTree
1 parent c15604b commit 263fa28

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ PROJECT_INCLUDE := -I./include -I./include/zenann
99
ALL_INCLUDES := $(PYBIND11_INCLUDES) $(PYTHON_INCLUDE) $(PROJECT_INCLUDE)
1010
ALL_LIBS := $(PYTHON_LIB)
1111

12-
SOURCES := src/IndexBase.cpp src/IVFFlatIndex.cpp python/zenann_pybind.cpp
12+
SOURCES := src/IndexBase.cpp src/IVFFlatIndex.cpp src/KDTreeIndex.cpp python/zenann_pybind.cpp
1313
EXT_SUFFIX := $(shell python3-config --extension-suffix)
1414
TARGET := build/zenann$(EXT_SUFFIX)
1515

python/zenann_pybind.cpp

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "IndexBase.h"
55
#include "IVFFlatIndex.h"
6+
#include "KDTreeIndex.h"
67

78
namespace py = pybind11;
89
using namespace zenann;
@@ -11,9 +12,7 @@ using namespace zenann;
1112
struct PyIndexBase : IndexBase {
1213
using IndexBase::IndexBase;
1314
void train() override { }
14-
SearchResult search(const Vector&, size_t) const override {
15-
return {};
16-
}
15+
SearchResult search(const Vector&, size_t) const override { return {}; }
1716
};
1817

1918
PYBIND11_MODULE(zenann, m) {
@@ -24,7 +23,7 @@ PYBIND11_MODULE(zenann, m) {
2423
.def_readonly("indices", &SearchResult::indices)
2524
.def_readonly("distances", &SearchResult::distances);
2625

27-
// Bind IndexBase with shared_ptr holder
26+
// Bind IndexBase
2827
py::class_<IndexBase, PyIndexBase, std::shared_ptr<IndexBase>>(m, "IndexBase")
2928
.def(py::init<size_t>(), py::arg("dim"))
3029
.def("build", &IndexBase::build, py::arg("data"),
@@ -34,9 +33,9 @@ PYBIND11_MODULE(zenann, m) {
3433
.def("search", &IndexBase::search, py::arg("query"), py::arg("k"),
3534
"Search k nearest neighbors")
3635
.def_property_readonly("dimension", &IndexBase::dimension,
37-
"Dimension of vectors in the index");
36+
"Dimension of vectors");
3837

39-
// Bind IVFFlatIndex as a subclass of IndexBase
38+
// Bind IVFFlatIndex
4039
py::class_<IVFFlatIndex, IndexBase, std::shared_ptr<IVFFlatIndex>>(m, "IVFFlatIndex")
4140
.def(py::init<size_t, size_t, size_t>(),
4241
py::arg("dim"), py::arg("nlist"), py::arg("nprobe") = 1)
@@ -45,7 +44,25 @@ PYBIND11_MODULE(zenann, m) {
4544
.def("train", &IVFFlatIndex::train,
4645
"Train the IVF index (run K-means and build inverted lists)")
4746
.def("search", &IVFFlatIndex::search, py::arg("query"), py::arg("k"),
48-
"Search top-k nearest neighbors using IVF index")
49-
.def("search_batch", &IVFFlatIndex::search_batch, py::arg("queries"), py::arg("k"),
50-
"Search top-k for a batch of queries");
47+
"Search top-k nearest neighbors")
48+
.def("search_batch", &IVFFlatIndex::search_batch, py::arg("queries"), py::arg("k"),
49+
"Search top-k for a batch of queries")
50+
// Persistence API
51+
.def("write_index", &IVFFlatIndex::write_index, py::arg("filename"),
52+
"Serialize the index to a file")
53+
.def_static("read_index", &IVFFlatIndex::read_index, py::arg("filename"),
54+
"Load an index from a file");
55+
56+
py::class_<KDTreeIndex, IndexBase, std::shared_ptr<KDTreeIndex>>(m, "KDTreeIndex")
57+
.def(py::init<size_t>(), py::arg("dim"))
58+
.def("build", &KDTreeIndex::build, py::arg("data"),
59+
"Add data and build KDTree index")
60+
.def("train", &KDTreeIndex::train,
61+
"Rebuild KDTree index")
62+
.def("search", &KDTreeIndex::search, py::arg("query"), py::arg("k"),
63+
"Exact k-NN search using KDTree")
64+
.def("write_index", &KDTreeIndex::write_index, py::arg("filename"),
65+
"Serialize KDTree index to file")
66+
.def_static("read_index", &KDTreeIndex::read_index, py::arg("filename"),
67+
"Load KDTree index from file");
5168
}

0 commit comments

Comments
 (0)