@@ -37,21 +37,33 @@ PYBIND11_MODULE(zenann, m) {
3737
3838 // Bind IVFFlatIndex
3939 py::class_<IVFFlatIndex, IndexBase, std::shared_ptr<IVFFlatIndex>>(m, " IVFFlatIndex" )
40- .def (py::init<size_t , size_t , size_t >(),
41- py::arg (" dim" ), py::arg (" nlist" ), py::arg (" nprobe" ) = 1 )
42- .def (" build" , &IVFFlatIndex::build, py::arg (" data" ),
43- " Add data and train the IVF index" )
44- .def (" train" , &IVFFlatIndex::train,
45- " Train the IVF index (run K-means and build inverted lists)" )
46- .def (" search" , &IVFFlatIndex::search, py::arg (" query" ), py::arg (" k" ),
47- " Search top-k nearest neighbors" )
48- .def (" search_batch" , &IVFFlatIndex::search_batch, py::arg (" queries" ), py::arg (" k" ),
49- " Search top-k for a batch of queries" )
50- // Persistence API
51- .def (" write_index" , &IVFFlatIndex::write_index, py::arg (" filename" ),
52- " Serialize the index to a file" )
53- .def_static (" read_index" , &IVFFlatIndex::read_index, py::arg (" filename" ),
54- " Load an index from a file" );
40+ .def (py::init<size_t , size_t , size_t >(),
41+ py::arg (" dim" ), py::arg (" nlist" ), py::arg (" nprobe" ) = 1 )
42+ .def (" build" , &IVFFlatIndex::build, py::arg (" data" ))
43+ .def (" train" , &IVFFlatIndex::train)
44+ .def (" search" ,
45+ py::overload_cast<const Vector&, size_t >(
46+ &IVFFlatIndex::search,
47+ py::const_),
48+ py::arg (" query" ), py::arg (" k" ))
49+ .def (" search" ,
50+ py::overload_cast<const Vector&, size_t , size_t >(
51+ &IVFFlatIndex::search,
52+ py::const_),
53+ py::arg (" query" ), py::arg (" k" ), py::arg (" nprobe" ))
54+ .def (" search_batch" ,
55+ py::overload_cast<const Dataset&, size_t >(
56+ &IVFFlatIndex::search_batch,
57+ py::const_),
58+ py::arg (" queries" ), py::arg (" k" ))
59+ .def (" search_batch" ,
60+ py::overload_cast<const Dataset&, size_t , size_t >(
61+ &IVFFlatIndex::search_batch,
62+ py::const_),
63+ py::arg (" queries" ), py::arg (" k" ), py::arg (" nprobe" ))
64+
65+ .def (" write_index" , &IVFFlatIndex::write_index, py::arg (" filename" ))
66+ .def_static (" read_index" , &IVFFlatIndex::read_index, py::arg (" filename" ));
5567
5668 py::class_<KDTreeIndex, IndexBase, std::shared_ptr<KDTreeIndex>>(m, " KDTreeIndex" )
5769 .def (py::init<size_t >(), py::arg (" dim" ))
0 commit comments