Skip to content

Commit d076f30

Browse files
committed
feat: pybind wrapper for HNSW
1 parent 4eb5822 commit d076f30

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

python/zenann_pybind.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "IndexBase.h"
55
#include "IVFFlatIndex.h"
66
#include "KDTreeIndex.h"
7+
#include "HNSWIndex.h"
78

89
namespace py = pybind11;
910
using namespace zenann;
@@ -61,7 +62,6 @@ PYBIND11_MODULE(zenann, m) {
6162
&IVFFlatIndex::search_batch,
6263
py::const_),
6364
py::arg("queries"), py::arg("k"), py::arg("nprobe"))
64-
6565
.def("write_index", &IVFFlatIndex::write_index, py::arg("filename"))
6666
.def_static("read_index", &IVFFlatIndex::read_index, py::arg("filename"));
6767

@@ -77,4 +77,32 @@ PYBIND11_MODULE(zenann, m) {
7777
"Serialize KDTree index to file")
7878
.def_static("read_index", &KDTreeIndex::read_index, py::arg("filename"),
7979
"Load KDTree index from file");
80+
81+
py::class_<HNSWIndex, IndexBase, std::shared_ptr<HNSWIndex>>(m, "HNSWIndex")
82+
.def(py::init<size_t,size_t,size_t>(),
83+
py::arg("dim"), py::arg("M"), py::arg("efConstruction")=200)
84+
.def("build", &HNSWIndex::build, py::arg("data"))
85+
.def("train", &HNSWIndex::train)
86+
.def("search", (SearchResult (HNSWIndex::*)(const Vector&,size_t) const)
87+
&HNSWIndex::search, py::arg("query"), py::arg("k"))
88+
.def("search", (SearchResult (HNSWIndex::*)(const Vector&,size_t,size_t) const)
89+
&HNSWIndex::search,
90+
py::arg("query"), py::arg("k"), py::arg("efSearch"))
91+
.def("search_batch", (std::vector<SearchResult> (HNSWIndex::*)(const Dataset&,size_t) const)
92+
&HNSWIndex::search_batch,
93+
py::arg("queries"), py::arg("k"))
94+
.def("search_batch", (std::vector<SearchResult> (HNSWIndex::*)(const Dataset&,size_t,size_t) const)
95+
&HNSWIndex::search_batch,
96+
py::arg("queries"), py::arg("k"), py::arg("efSearch"))
97+
.def("set_ef_search", &HNSWIndex::set_ef_search, py::arg("efSearch"))
98+
.def("reorder_layout", (void (HNSWIndex::*)()) &HNSWIndex::reorder_layout)
99+
.def("reorder_layout", (void (HNSWIndex::*)(const std::string&))
100+
&HNSWIndex::reorder_layout, py::arg("mapping_file"))
101+
.def("write_index", &HNSWIndex::write_index, py::arg("filename"))
102+
.def_static("read_index", &HNSWIndex::read_index, py::arg("filename"))
103+
.def_static("compute_recall_with_mapping",
104+
&HNSWIndex::compute_recall_with_mapping,
105+
py::arg("groundtruth"), py::arg("predicted_flat"),
106+
py::arg("nq"), py::arg("k"), py::arg("mapping_file"))
107+
;
80108
}

0 commit comments

Comments
 (0)