99#include < assert.h>
1010
1111namespace py = pybind11;
12+ using namespace pybind11 ::literals; // needed to bring in _a literal
1213
1314/*
1415 * replacement for the openmp '#pragma omp parallel for' directive
@@ -73,6 +74,12 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
7374
7475}
7576
77+ inline void assert_true (bool expr, const std::string & msg) {
78+ if (expr == false )
79+ throw std::runtime_error (" Unpickle Error: " +msg);
80+ return ;
81+ }
82+
7683
7784
7885template <typename dist_t , typename data_t =float >
@@ -98,7 +105,7 @@ class Index {
98105
99106 default_ef=10 ;
100107 }
101-
108+
102109 static const int ser_version = 1 ; // serialization version
103110
104111 std::string space_name;
@@ -278,15 +285,11 @@ class Index {
278285 return ids;
279286 }
280287
281- inline void assert_true (bool expr, const std::string & msg) {
282- if (expr == false )
283- throw std::runtime_error (" assert failed: " +msg);
284- return ;
285- }
288+
289+ py::dict getAnnData () const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */
290+
286291
287292
288- py::tuple getAnnData () const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */
289-
290293 std::unique_lock <std::mutex> templock (appr_alg->global );
291294
292295 unsigned int level0_npy_size = appr_alg->cur_element_count * appr_alg->size_data_per_element_ ;
@@ -345,140 +348,153 @@ class Index {
345348 delete[] f;
346349 });
347350
348- return py::make_tuple (appr_alg->offsetLevel0_ ,
349- appr_alg->max_elements_ ,
350- appr_alg->cur_element_count ,
351- appr_alg->size_data_per_element_ ,
352- appr_alg->label_offset_ ,
353- appr_alg->offsetData_ ,
354- appr_alg->maxlevel_ ,
355- appr_alg->enterpoint_node_ ,
356- appr_alg->maxM_ ,
357- appr_alg->maxM0_ ,
358- appr_alg->M_ ,
359- appr_alg->mult_ ,
360- appr_alg->ef_construction_ ,
361- appr_alg->ef_ ,
362- appr_alg->has_deletions_ ,
363- appr_alg->size_links_per_element_ ,
364- py::array_t <hnswlib::labeltype>(
365- {appr_alg->label_lookup_ .size ()}, // shape
366- {sizeof (hnswlib::labeltype)}, // C-style contiguous strides for double
367- label_lookup_key_npy, // the data pointer
368- free_when_done_lb),
369- py::array_t <hnswlib::tableint>(
370- {appr_alg->label_lookup_ .size ()}, // shape
371- {sizeof (hnswlib::tableint)}, // C-style contiguous strides for double
372- label_lookup_val_npy, // the data pointer
373- free_when_done_id),
374- py::array_t <int >(
375- {appr_alg->element_levels_ .size ()}, // shape
376- {sizeof (int )}, // C-style contiguous strides for double
377- element_levels_npy, // the data pointer
378- free_when_done_lvl),
379- py::array_t <char >(
380- {level0_npy_size}, // shape
381- {sizeof (char )}, // C-style contiguous strides for double
382- data_level0_npy, // the data pointer
383- free_when_done_l0),
384- py::array_t <char >(
385- {link_npy_size}, // shape
386- {sizeof (char )}, // C-style contiguous strides for double
387- link_list_npy, // the data pointer
388- free_when_done_ll)
389- );
351+ /* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */
352+ /* for full reproducibility / to avoid re-initializing generators inside Index::createFromParams */
353+
354+ return py::dict (
355+ " offset_level0" _a=appr_alg->offsetLevel0_ ,
356+ " max_elements" _a=appr_alg->max_elements_ ,
357+ " cur_element_count" _a=appr_alg->cur_element_count ,
358+ " size_data_per_element" _a=appr_alg->size_data_per_element_ ,
359+ " label_offset" _a=appr_alg->label_offset_ ,
360+ " offset_data" _a=appr_alg->offsetData_ ,
361+ " max_level" _a=appr_alg->maxlevel_ ,
362+ " enterpoint_node" _a=appr_alg->enterpoint_node_ ,
363+ " max_M" _a=appr_alg->maxM_ ,
364+ " max_M0" _a=appr_alg->maxM0_ ,
365+ " M" _a=appr_alg->M_ ,
366+ " mult" _a=appr_alg->mult_ ,
367+ " ef_construction" _a=appr_alg->ef_construction_ ,
368+ " ef" _a=appr_alg->ef_ ,
369+ " has_deletions" _a=appr_alg->has_deletions_ ,
370+ " size_links_per_element" _a=appr_alg->size_links_per_element_ ,
371+
372+ " label_lookup_external" _a=py::array_t <hnswlib::labeltype>(
373+ {appr_alg->label_lookup_ .size ()}, // shape
374+ {sizeof (hnswlib::labeltype)}, // C-style contiguous strides for double
375+ label_lookup_key_npy, // the data pointer
376+ free_when_done_lb),
377+
378+ " label_lookup_internal" _a=py::array_t <hnswlib::tableint>(
379+ {appr_alg->label_lookup_ .size ()}, // shape
380+ {sizeof (hnswlib::tableint)}, // C-style contiguous strides for double
381+ label_lookup_val_npy, // the data pointer
382+ free_when_done_id),
383+
384+ " element_levels" _a=py::array_t <int >(
385+ {appr_alg->element_levels_ .size ()}, // shape
386+ {sizeof (int )}, // C-style contiguous strides for double
387+ element_levels_npy, // the data pointer
388+ free_when_done_lvl),
389+
390+ // linkLists_,element_levels_,data_level0_memory_
391+ " data_level0" _a=py::array_t <char >(
392+ {level0_npy_size}, // shape
393+ {sizeof (char )}, // C-style contiguous strides for double
394+ data_level0_npy, // the data pointer
395+ free_when_done_l0),
396+
397+ " link_lists" _a=py::array_t <char >(
398+ {link_npy_size}, // shape
399+ {sizeof (char )}, // C-style contiguous strides for double
400+ link_list_npy, // the data pointer
401+ free_when_done_ll)
402+
403+ );
404+
390405
391406 }
392407
393408
394- py::tuple getIndexParams () const {
395- /* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */
396- /* for full reproducibility / to avoid re-initializing generators inside Index::createFromParams */
397-
398- return py::make_tuple (py::int_ (Index<float >::ser_version), // serialization version
399-
400- /* TODO: convert the following two py::tuple's to py::dict */
401- py::make_tuple (space_name, dim, index_inited, ep_added, normalize, num_threads_default, seed, default_ef),
402- index_inited == true ? getAnnData () : py::make_tuple ()); /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */
403-
404-
409+ py::dict getIndexParams () const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */
410+ auto params = py::dict (
411+ " ser_version" _a=py::int_ (Index<float >::ser_version), // serialization version
412+ " space" _a=space_name,
413+ " dim" _a=dim,
414+ " index_inited" _a=index_inited,
415+ " ep_added" _a=ep_added,
416+ " normalize" _a=normalize,
417+ " num_threads" _a=num_threads_default,
418+ " seed" _a=seed
419+ );
420+
421+ if (index_inited == false )
422+ return py::dict ( **params, " ef" _a=default_ef);
405423
424+ auto ann_params = getAnnData ();
425+
426+ return py::dict (**params, **ann_params);
406427 }
407428
408429
409- static Index<float > * createFromParams (const py::tuple t) {
410-
411- if (py::int_ (Index<float >::ser_version) != t[0 ].cast <int >()) // check serialization version
412- throw std::runtime_error (" Serialization version mismatch!" );
430+ static Index<float > * createFromParams (const py::dict d) {
413431
414- py::tuple index_params=t[ 1 ]. cast <py::tuple>(); /* TODO: convert index_params from py::tuple to py::dict */
415- py::tuple ann_params=t[ 2 ].cast <py::tuple >(); /* TODO: convert ann_params from py::tuple to py::dict */
432+ // check serialization version
433+ assert_true ((( int ) py::int_ (Index< float >::ser_version)) >= d[ " ser_version " ].cast <int >(), " Invalid serialization version! " );
416434
417- auto space_name_=index_params[ 0 ].cast <std::string>();
418- auto dim_=index_params[ 1 ].cast <int >();
419- auto index_inited_=index_params[ 2 ].cast <bool >();
435+ auto space_name_=d[ " space " ].cast <std::string>();
436+ auto dim_=d[ " dim " ].cast <int >();
437+ auto index_inited_=d[ " index_inited " ].cast <bool >();
420438
421- Index<float > *new_index = new Index<float >(index_params[ 0 ]. cast <std::string>(), index_params[ 1 ]. cast < int >() );
439+ Index<float > *new_index = new Index<float >(space_name_, dim_ );
422440
423441 /* TODO: deserialize state of random generators into new_index->level_generator_ and new_index->update_probability_generator_ */
424442 /* for full reproducibility / state of generators is serialized inside Index::getIndexParams */
425- new_index->seed = index_params[ 6 ].cast <size_t >();
443+ new_index->seed = d[ " seed " ].cast <size_t >();
426444
427445 if (index_inited_){
428- new_index->appr_alg = new hnswlib::HierarchicalNSW<dist_t >(new_index->l2space , ann_params[ 1 ].cast <size_t >(), ann_params[ 10 ].cast <size_t >(), ann_params[ 12 ].cast <size_t >(), new_index->seed );
429- new_index->cur_l = ann_params[ 2 ].cast <size_t >();
446+ new_index->appr_alg = new hnswlib::HierarchicalNSW<dist_t >(new_index->l2space , d[ " max_elements " ].cast <size_t >(), d[ " M " ].cast <size_t >(), d[ " ef_construction " ].cast <size_t >(), new_index->seed );
447+ new_index->cur_l = d[ " cur_element_count " ].cast <size_t >();
430448 }
431449
432450 new_index->index_inited = index_inited_;
433- new_index->ep_added =index_params[ 3 ].cast <bool >();
434- new_index->num_threads_default =index_params[ 5 ].cast <int >();
435- new_index->default_ef =index_params[ 7 ].cast <size_t >();
451+ new_index->ep_added =d[ " ep_added " ].cast <bool >();
452+ new_index->num_threads_default =d[ " num_threads " ].cast <int >();
453+ new_index->default_ef =d[ " ef " ].cast <size_t >();
436454
437455 if (index_inited_)
438- new_index->setAnnData (ann_params);
439-
456+ new_index->setAnnData (d);
440457
441458 return new_index;
442459 }
443460
444461 static Index<float > * createFromIndex (const Index<float > & index) {
445- /* WARNING: Index::getIndexParams is not thread-safe with Index::addItems */
446- return createFromParams (index.getIndexParams ());
462+ return createFromParams (index.getIndexParams ());
447463 }
448464
449-
450- void setAnnData (const py::tuple t) {
451- /* WARNING: Index::setAnnData is not thread-safe with Index::addItems */
452-
465+ void setAnnData (const py::dict d) { /* WARNING: Index::setAnnData is not thread-safe with Index::addItems */
466+
467+
453468 std::unique_lock <std::mutex> templock (appr_alg->global );
454469
455- assert_true (appr_alg->offsetLevel0_ == t[0 ].cast <size_t >(), " Invalid value of offsetLevel0_ " );
456- assert_true (appr_alg->max_elements_ == t[1 ].cast <size_t >(), " Invalid value of max_elements_ " );
470+ assert_true (appr_alg->offsetLevel0_ == d[" offset_level0" ].cast <size_t >(), " Invalid value of offsetLevel0_ " );
471+ assert_true (appr_alg->max_elements_ == d[" max_elements" ].cast <size_t >(), " Invalid value of max_elements_ " );
472+
473+ appr_alg->cur_element_count = d[" cur_element_count" ].cast <size_t >();
457474
458- appr_alg->cur_element_count = t[2 ].cast <size_t >();
475+ assert_true (appr_alg->size_data_per_element_ == d[" size_data_per_element" ].cast <size_t >(), " Invalid value of size_data_per_element_ " );
476+ assert_true (appr_alg->label_offset_ == d[" label_offset" ].cast <size_t >(), " Invalid value of label_offset_ " );
477+ assert_true (appr_alg->offsetData_ == d[" offset_data" ].cast <size_t >(), " Invalid value of offsetData_ " );
459478
460- assert_true (appr_alg->size_data_per_element_ == t[3 ].cast <size_t >(), " Invalid value of size_data_per_element_ " );
461- assert_true (appr_alg->label_offset_ == t[4 ].cast <size_t >(), " Invalid value of label_offset_ " );
462- assert_true (appr_alg->offsetData_ == t[5 ].cast <size_t >(), " Invalid value of offsetData_ " );
479+ appr_alg->maxlevel_ = d[" max_level" ].cast <int >();
480+ appr_alg->enterpoint_node_ = d[" enterpoint_node" ].cast <hnswlib::tableint>();
463481
464- appr_alg->maxlevel_ = t[6 ].cast <int >();
465- appr_alg->enterpoint_node_ = t[7 ].cast <hnswlib::tableint>();
482+ assert_true (appr_alg->maxM_ == d[" max_M" ].cast <size_t >(), " Invalid value of maxM_ " );
483+ assert_true (appr_alg->maxM0_ == d[" max_M0" ].cast <size_t >(), " Invalid value of maxM0_ " );
484+ assert_true (appr_alg->M_ == d[" M" ].cast <size_t >(), " Invalid value of M_ " );
485+ assert_true (appr_alg->mult_ == d[" mult" ].cast <double >(), " Invalid value of mult_ " );
486+ assert_true (appr_alg->ef_construction_ == d[" ef_construction" ].cast <size_t >(), " Invalid value of ef_construction_ " );
466487
467- assert_true (appr_alg->maxM_ == t[8 ].cast <size_t >(), " Invalid value of maxM_ " );
468- assert_true (appr_alg->maxM0_ == t[9 ].cast <size_t >(), " Invalid value of maxM0_ " );
469- assert_true (appr_alg->M_ == t[10 ].cast <size_t >(), " Invalid value of M_ " );
470- assert_true (appr_alg->mult_ == t[11 ].cast <double >(), " Invalid value of mult_ " );
471- assert_true (appr_alg->ef_construction_ == t[12 ].cast <size_t >(), " Invalid value of ef_construction_ " );
488+ appr_alg->ef_ = d[" ef" ].cast <size_t >();
489+ appr_alg->has_deletions_ =d[" has_deletions" ].cast <bool >();
472490
473- appr_alg->ef_ = t[13 ].cast <size_t >();
474- appr_alg->has_deletions_ =t[14 ].cast <bool >();
475- assert_true (appr_alg->size_links_per_element_ == t[15 ].cast <size_t >(), " Invalid value of size_links_per_element_ " );
491+ assert_true (appr_alg->size_links_per_element_ == d[" size_links_per_element" ].cast <size_t >(), " Invalid value of size_links_per_element_ " );
476492
477- auto label_lookup_key_npy = t[ 16 ].cast <py::array_t < hnswlib::labeltype, py::array::c_style | py::array::forcecast > >();
478- auto label_lookup_val_npy = t[ 17 ].cast <py::array_t < hnswlib::tableint, py::array::c_style | py::array::forcecast > >();
479- auto element_levels_npy = t[ 18 ].cast <py::array_t < int , py::array::c_style | py::array::forcecast > >();
480- auto data_level0_npy = t[ 19 ].cast <py::array_t < char , py::array::c_style | py::array::forcecast > >();
481- auto link_list_npy = t[ 20 ].cast <py::array_t < char , py::array::c_style | py::array::forcecast > >();
493+ auto label_lookup_key_npy = d[ " label_lookup_external " ].cast <py::array_t < hnswlib::labeltype, py::array::c_style | py::array::forcecast > >();
494+ auto label_lookup_val_npy = d[ " label_lookup_internal " ].cast <py::array_t < hnswlib::tableint, py::array::c_style | py::array::forcecast > >();
495+ auto element_levels_npy = d[ " element_levels " ].cast <py::array_t < int , py::array::c_style | py::array::forcecast > >();
496+ auto data_level0_npy = d[ " data_level0 " ].cast <py::array_t < char , py::array::c_style | py::array::forcecast > >();
497+ auto link_list_npy = d[ " link_lists " ].cast <py::array_t < char , py::array::c_style | py::array::forcecast > >();
482498
483499 for (size_t i = 0 ; i < appr_alg->cur_element_count ; i++){
484500 if (label_lookup_val_npy.data ()[i] < 0 ){
@@ -516,7 +532,6 @@ class Index {
516532
517533 }
518534 }
519-
520535}
521536
522537 py::object knnQuery_return_numpy (py::object input, size_t k = 1 , int num_threads = -1 ) {
@@ -640,9 +655,9 @@ PYBIND11_PLUGIN(hnswlib) {
640655 py::module m (" hnswlib" );
641656
642657 py::class_<Index<float >>(m, " Index" )
643- .def (py::init (&Index<float >::createFromParams), py::arg (" params" ))
658+ .def (py::init (&Index<float >::createFromParams), py::arg (" params" ))
644659 /* WARNING: Index::createFromIndex is not thread-safe with Index::addItems */
645- .def (py::init (&Index<float >::createFromIndex), py::arg (" index" ))
660+ .def (py::init (&Index<float >::createFromIndex), py::arg (" index" ))
646661 .def (py::init<const std::string &, const int >(), py::arg (" space" ), py::arg (" dim" ))
647662 .def (" init_index" , &Index<float >::init_new_index, py::arg (" max_elements" ), py::arg (" M" )=16 , py::arg (" ef_construction" )=200 , py::arg (" random_seed" )=100 )
648663 .def (" knn_query" , &Index<float >::knnQuery_return_numpy, py::arg (" data" ), py::arg (" k" )=1 , py::arg (" num_threads" )=-1 )
@@ -682,14 +697,13 @@ PYBIND11_PLUGIN(hnswlib) {
682697
683698 .def (py::pickle (
684699 [](const Index<float > &ind) { // __getstate__
685- /* Return a tuple that fully encodes the state of the object */
686- /* WARNING: Index::getIndexParams is not thread-safe with Index::addItems */
687- return ind.getIndexParams ();
700+ return py::make_tuple (ind.getIndexParams ()); /* Return dict (wrapped in a tuple) that fully encodes state of the Index object */
688701 },
689702 [](py::tuple t) { // __setstate__
690- if (t.size () != 3 )
703+ if (t.size () != 1 )
691704 throw std::runtime_error (" Invalid state!" );
692- return Index<float >::createFromParams (t);
705+
706+ return Index<float >::createFromParams (t[0 ].cast <py::dict>());
693707 }
694708 ))
695709
0 commit comments