33
44#include " IndexBase.h"
55#include " IVFFlatIndex.h"
6+ #include " KDTreeIndex.h"
67
78namespace py = pybind11;
89using namespace zenann ;
@@ -11,9 +12,7 @@ using namespace zenann;
1112struct PyIndexBase : IndexBase {
1213 using IndexBase::IndexBase;
1314 void train () override { }
14- SearchResult search (const Vector&, size_t ) const override {
15- return {};
16- }
15+ SearchResult search (const Vector&, size_t ) const override { return {}; }
1716};
1817
1918PYBIND11_MODULE (zenann, m) {
@@ -24,7 +23,7 @@ PYBIND11_MODULE(zenann, m) {
2423 .def_readonly (" indices" , &SearchResult::indices)
2524 .def_readonly (" distances" , &SearchResult::distances);
2625
27- // Bind IndexBase with shared_ptr holder
26+ // Bind IndexBase
2827 py::class_<IndexBase, PyIndexBase, std::shared_ptr<IndexBase>>(m, " IndexBase" )
2928 .def (py::init<size_t >(), py::arg (" dim" ))
3029 .def (" build" , &IndexBase::build, py::arg (" data" ),
@@ -34,9 +33,9 @@ PYBIND11_MODULE(zenann, m) {
3433 .def (" search" , &IndexBase::search, py::arg (" query" ), py::arg (" k" ),
3534 " Search k nearest neighbors" )
3635 .def_property_readonly (" dimension" , &IndexBase::dimension,
37- " Dimension of vectors in the index " );
36+ " Dimension of vectors" );
3837
39- // Bind IVFFlatIndex as a subclass of IndexBase
38+ // Bind IVFFlatIndex
4039 py::class_<IVFFlatIndex, IndexBase, std::shared_ptr<IVFFlatIndex>>(m, " IVFFlatIndex" )
4140 .def (py::init<size_t , size_t , size_t >(),
4241 py::arg (" dim" ), py::arg (" nlist" ), py::arg (" nprobe" ) = 1 )
@@ -45,7 +44,25 @@ PYBIND11_MODULE(zenann, m) {
4544 .def (" train" , &IVFFlatIndex::train,
4645 " Train the IVF index (run K-means and build inverted lists)" )
4746 .def (" search" , &IVFFlatIndex::search, py::arg (" query" ), py::arg (" k" ),
48- " Search top-k nearest neighbors using IVF index" )
49- .def (" search_batch" , &IVFFlatIndex::search_batch, py::arg (" queries" ), py::arg (" k" ),
50- " Search top-k for a batch of queries" );
47+ " Search top-k nearest neighbors" )
48+ .def (" search_batch" , &IVFFlatIndex::search_batch, py::arg (" queries" ), py::arg (" k" ),
49+ " Search top-k for a batch of queries" )
50+ // Persistence API
51+ .def (" write_index" , &IVFFlatIndex::write_index, py::arg (" filename" ),
52+ " Serialize the index to a file" )
53+ .def_static (" read_index" , &IVFFlatIndex::read_index, py::arg (" filename" ),
54+ " Load an index from a file" );
55+
56+ py::class_<KDTreeIndex, IndexBase, std::shared_ptr<KDTreeIndex>>(m, " KDTreeIndex" )
57+ .def (py::init<size_t >(), py::arg (" dim" ))
58+ .def (" build" , &KDTreeIndex::build, py::arg (" data" ),
59+ " Add data and build KDTree index" )
60+ .def (" train" , &KDTreeIndex::train,
61+ " Rebuild KDTree index" )
62+ .def (" search" , &KDTreeIndex::search, py::arg (" query" ), py::arg (" k" ),
63+ " Exact k-NN search using KDTree" )
64+ .def (" write_index" , &KDTreeIndex::write_index, py::arg (" filename" ),
65+ " Serialize KDTree index to file" )
66+ .def_static (" read_index" , &KDTreeIndex::read_index, py::arg (" filename" ),
67+ " Load KDTree index from file" );
5168}
0 commit comments