
unified interface for all index types #66

@BlaiseMuhirwa


We should expose all index types through a single type-erased interface, such as the following (perhaps with less reliance on virtual dispatch):

```cpp
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace py = pybind11;

/**
 * Type-erased interface for the index class. Every concrete index
 * specialization derives from this, so Python only sees one base type.
 */
class PyIndexBase : public std::enable_shared_from_this<PyIndexBase> {
protected:
  // (distances, labels) arrays returned by the search methods.
  using DistancesLabelsPair = std::pair<py::array_t<float>, py::array_t<int>>;

public:
  virtual ~PyIndexBase() = default;

  virtual void add(const py::array &data, int ef_construction,
                   int num_initializations, py::object labels = py::none()) = 0;
  virtual std::shared_ptr<PyIndexBase> allocateNodes(
      const py::array_t<float, py::array::c_style | py::array::forcecast>
          &data) = 0;

  // Single-query and batched search; both return (distances, labels).
  virtual DistancesLabelsPair searchSingle(const py::array &query, int K,
                                           int ef_search,
                                           int num_initializations) = 0;
  virtual DistancesLabelsPair search(const py::array &queries, int K,
                                     int ef_search,
                                     int num_initializations) = 0;

  // Persistence and graph-level operations.
  virtual void save(const std::string &filename) = 0;
  virtual void buildGraphLinks(const std::string &mtx_filename) = 0;
  virtual std::vector<std::vector<uint32_t>> getGraphOutdegreeTable() = 0;
  virtual void reorder(const std::vector<std::string> &strategies) = 0;

  // Threading and query statistics.
  virtual void setNumThreads(uint32_t num_threads) = 0;
  virtual uint32_t getNumThreads() = 0;
  virtual uint64_t getQueryDistanceComputations() const = 0;
};
```
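To make the type erasure concrete, here is a pared-down sketch in plain C++ (no pybind11): a templated index implements a virtual interface, and callers hold only a reference or pointer to the base. `IndexBase`, `BruteForceIndex`, `SquaredL2`, `NegInnerProduct`, and `nearestLabel` are hypothetical names, and the brute-force search merely stands in for the real graph search.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical, pared-down stand-in for PyIndexBase (no pybind11 types).
class IndexBase {
public:
  virtual ~IndexBase() = default;
  virtual void add(const std::vector<float> &vec) = 0;
  // Returns (distance, label) of the nearest stored vector; label -1 if empty.
  virtual std::pair<float, int> searchSingle(const std::vector<float> &query) const = 0;
  virtual std::size_t size() const = 0;
};

// Hypothetical distance policies, standing in for SquaredL2Distance and
// InnerProductDistance. Smaller values mean "closer" for both.
struct SquaredL2 {
  static float compute(const std::vector<float> &a, const std::vector<float> &b) {
    assert(a.size() == b.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i) {
      const float diff = a[i] - b[i];
      sum += diff * diff;
    }
    return sum;
  }
};

struct NegInnerProduct {
  static float compute(const std::vector<float> &a, const std::vector<float> &b) {
    assert(a.size() == b.size());
    float dot = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i) dot += a[i] * b[i];
    return -dot;  // negate so that a larger inner product ranks first
  }
};

// One concrete index per distance policy; brute force keeps the sketch short.
template <typename Distance>
class BruteForceIndex : public IndexBase {
public:
  void add(const std::vector<float> &vec) override { data_.push_back(vec); }

  std::pair<float, int> searchSingle(const std::vector<float> &query) const override {
    std::pair<float, int> best{0.0f, -1};
    for (std::size_t i = 0; i < data_.size(); ++i) {
      // Distance::compute is resolved at compile time; no virtual call here.
      const float d = Distance::compute(query, data_[i]);
      if (best.second < 0 || d < best.first) best = {d, static_cast<int>(i)};
    }
    return best;
  }

  std::size_t size() const override { return data_.size(); }

private:
  std::vector<std::vector<float>> data_;
};

// Written once against the base class; works for every specialization.
int nearestLabel(const IndexBase &index, const std::vector<float> &query) {
  return index.searchSingle(query).second;
}
```

With the points (0, 0) and (3, 4) stored, the query (1, 1) is nearest to label 0 under squared L2 but to label 1 under the negated inner product, even though `nearestLabel` is compiled only once against `IndexBase`. The only virtual call is the outer `searchSingle`; the distance loop stays statically dispatched.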

Then, in the bindings, we could bind every distance and data-type specialization against that one class with something like this:

```cpp
// Binds every (distance, data type) specialization onto the shared index class.
struct BindAllDataTypes {
  explicit BindAllDataTypes(
      py::class_<PyIndexBase, std::shared_ptr<PyIndexBase>> &index_class)
      : _index_class(index_class) {}

  // Called once per DataType listed in for_each_data_type below.
  template <DataType data_type> void operator()() {
    bindSpecialization<SquaredL2Distance<data_type>, int>(_index_class);
    bindSpecialization<InnerProductDistance<data_type>, int>(_index_class);
  }

  py::class_<PyIndexBase, std::shared_ptr<PyIndexBase>> &_index_class;
};

// Inside the PYBIND11_MODULE body, where index_class is the
// py::class_<PyIndexBase, std::shared_ptr<PyIndexBase>> object:
BindAllDataTypes binder(index_class);
flatnav::util::for_each_data_type<
    BindAllDataTypes, DataType::float32, DataType::int8,
    DataType::uint8>::apply(std::forward<BindAllDataTypes>(binder));
```
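The implementation of `flatnav::util::for_each_data_type` is not shown in this issue. Assuming it behaves roughly like the hypothetical `ForEachDataType` helper below (C++17 fold expression), each listed `DataType` triggers one call to the functor's templated `operator()`, and every call reaches the same shared state, which in the bindings is the single `_index_class`. `DataType`, `dataTypeName`, and `RecordDataTypes` here are illustrative stand-ins only.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for flatnav's DataType enum.
enum class DataType { float32, int8, uint8 };

constexpr const char *dataTypeName(DataType type) {
  switch (type) {
  case DataType::float32:
    return "float32";
  case DataType::int8:
    return "int8";
  case DataType::uint8:
    return "uint8";
  }
  return "unknown";
}

// Sketch of a for_each_data_type-style helper: calls
// functor.template operator()<T>() once per listed value, in order.
template <typename Functor, DataType... Types>
struct ForEachDataType {
  static void apply(Functor &&functor) {
    // Fold over the comma operator: evaluated left to right.
    (functor.template operator()<Types>(), ...);
  }
};

// Plays the role of BindAllDataTypes: every call mutates the same shared
// object (a vector here, the single _index_class in the bindings).
struct RecordDataTypes {
  std::vector<std::string> &visited;

  template <DataType data_type> void operator()() {
    visited.emplace_back(dataTypeName(data_type));
  }
};
```

Because `apply` takes the functor by rvalue reference in this sketch, the caller has to pass an rvalue, which is what `std::forward<BindAllDataTypes>(binder)` (equivalent to `std::move(binder)` here) provides.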

Currently, this doesn't work as intended: pybind11 binds the methods only for the first data type (here, DataType::float32). When the loop reaches the next type, DataType::int8, the index class already has methods such as add and search bound, so the remaining specializations are not exposed. We need a way to bind all of the specializations behind this single interface.
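One possible direction, offered as a sketch rather than a confirmed fix: every method on `PyIndexBase` is pure virtual, so calling it through the base class already dispatches to the right specialization. In principle, `add`, `search`, and the rest could therefore be bound once on the `PyIndexBase` class, and each specialization would only need to contribute a way to construct itself. The plain C++ below models that second half with a factory registry keyed by metric and data type. `IndexBase`, `Index`, `registerSpecialization`, `createIndex`, and the `"l2"`/`"ip"` keys are hypothetical and not part of flatnav; `describe()` exists only to make the dispatch observable.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical minimal interface standing in for PyIndexBase.
class IndexBase {
public:
  virtual ~IndexBase() = default;
  virtual std::string describe() const = 0;
};

// Hypothetical distance templates standing in for flatnav's distances.
template <typename T> struct SquaredL2Distance { static constexpr const char *metric = "l2"; };
template <typename T> struct InnerProductDistance { static constexpr const char *metric = "ip"; };

template <typename Distance, typename T>
class Index : public IndexBase {
public:
  std::string describe() const override {
    return std::string(Distance::metric) + "/" + std::to_string(sizeof(T) * 8) + "-bit";
  }
};

// Factory table keyed by (metric, data type name).
using Factory = std::function<std::shared_ptr<IndexBase>()>;
using Registry = std::map<std::pair<std::string, std::string>, Factory>;

Registry &registry() {
  static Registry table;  // function-local static: initialized on first use
  return table;
}

// Each specialization registers how to construct itself; it does not
// redefine add/search, which would be bound once on the base class.
template <typename Distance, typename T>
void registerSpecialization(const std::string &data_type) {
  const bool inserted =
      registry()
          .emplace(std::make_pair(std::string(Distance::metric), data_type),
                   [] { return std::make_shared<Index<Distance, T>>(); })
          .second;
  if (!inserted) throw std::logic_error("specialization registered twice");
}

// Single entry point that could be exposed to Python.
std::shared_ptr<IndexBase> createIndex(const std::string &metric,
                                       const std::string &data_type) {
  const auto it = registry().find({metric, data_type});
  if (it == registry().end())
    throw std::invalid_argument("unsupported metric/data type: " + metric + "/" + data_type);
  return it->second();
}
```

Registering the same (metric, data type) pair twice throws instead of silently shadowing the first entry, so each name still has exactly one binding.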
