Skip to content

Commit 82ca0a2

Browse files
authored
Use preloaded matrices as input for kmeans query call (#47)
1 parent ee73170 commit 82ca0a2

File tree

3 files changed

+32
-35
lines changed

3 files changed

+32
-35
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ def __init__(self, uri, dtype: np.dtype):
9292
self.dtype = dtype
9393

9494
ctx = Ctx({}) # TODO pass in a context
95-
# TODO self._db = load_as_matrix(self.db_uri)
95+
self._db = load_as_matrix(self.parts_db_uri)
9696
self._centroids = load_as_matrix(self.centroids_uri)
9797
self._index = read_vector_u64(ctx, self.index_uri)
98-
# self._ids = load_as_matrix(self.ids_uri)
98+
self._ids = read_vector_u64(ctx, self.ids_uri)
9999

100100
def query(self, targets: np.ndarray, k=10, nqueries=10, nthreads=8, nprobe=1):
101101
"""
@@ -122,12 +122,12 @@ def query(self, targets: np.ndarray, k=10, nqueries=10, nthreads=8, nprobe=1):
122122
targets_m_a[:] = targets
123123

124124
r = query_kmeans(
125-
self.dtype,
126-
self.parts_db_uri,
125+
self._db.dtype,
126+
self._db,
127127
self._centroids,
128128
targets_m,
129129
self._index,
130-
self.ids_uri,
130+
self._ids,
131131
nprobe=nprobe,
132132
k_nn=k,
133133
nth=True, # ??

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -74,26 +74,40 @@ static void declareColMajorMatrix(py::module& mod, std::string const& suffix) {
7474

7575
}
7676

77+
template <typename T>
78+
static void declare_pyarray_to_matrix(py::module& m, const std::string& suffix) {
79+
m.def(("pyarray_copyto_matrix" + suffix).c_str(),
80+
[](py::array_t<T> arr) -> ColMajorMatrix<T> {
81+
py::buffer_info info = arr.request();
82+
if (info.ndim != 2)
83+
throw std::runtime_error("Number of dimensions must be two");
84+
if (info.format != py::format_descriptor<T>::format())
85+
throw std::runtime_error("Mismatched buffer format!");
86+
87+
auto data = std::unique_ptr<T[]>{new T[info.shape[0] * info.shape[1]]};
88+
auto r = ColMajorMatrix<T>(std::move(data), info.shape[0], info.shape[1]);
89+
return r;
90+
});
91+
}
92+
7793
template <typename T>
7894
static void declare_kmeans_query(py::module& m, const std::string& suffix) {
7995
m.def(("kmeans_query_" + suffix).c_str(),
80-
[](Ctx ctx,
81-
const std::string& part_uri,
96+
[](const ColMajorMatrix<T>& parts,
8297
const ColMajorMatrix<float>& centroids,
8398
const ColMajorMatrix<float>& query_vectors,
8499
std::vector<uint64_t>& indices,
85-
const std::string& id_uri,
100+
std::vector<shuffled_ids_type>& ids,
86101
size_t nprobe,
87102
size_t k_nn,
88103
bool nth,
89-
size_t nthreads) -> ColMajorMatrix<size_t> {
104+
size_t nthreads) -> ColMajorMatrix<size_t> { // TODO change return type
90105
auto r = detail::ivf::qv_query_heap_infinite_ram<T>(
91-
ctx,
92-
part_uri,
106+
parts,
93107
centroids,
94108
query_vectors,
95109
indices,
96-
id_uri,
110+
ids,
97111
nprobe,
98112
k_nn,
99113
nth,
@@ -102,22 +116,6 @@ static void declare_kmeans_query(py::module& m, const std::string& suffix) {
102116
}, py::keep_alive<1,2>());
103117
}
104118

105-
template <typename T>
106-
static void declare_pyarray_to_matrix(py::module& m, const std::string& suffix) {
107-
m.def(("pyarray_copyto_matrix" + suffix).c_str(),
108-
[](py::array_t<T> arr) -> ColMajorMatrix<T> {
109-
py::buffer_info info = arr.request();
110-
if (info.ndim != 2)
111-
throw std::runtime_error("Number of dimensions must be two");
112-
if (info.format != py::format_descriptor<T>::format())
113-
throw std::runtime_error("Mismatched buffer format!");
114-
115-
auto data = std::unique_ptr<T[]>{new T[info.shape[0] * info.shape[1]]};
116-
auto r = ColMajorMatrix<T>(std::move(data), info.shape[0], info.shape[1]);
117-
return r;
118-
});
119-
}
120-
121119

122120
// Declarations for typed subclasses of ColMajorMatrix
123121
template <typename P>

apis/python/src/tiledb/vector_search/module.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ def query_vq(db: "colMajorMatrix", *args):
8484

8585
def query_kmeans(
8686
dtype: np.dtype,
87-
parts_uri: str,
87+
parts_db: "colMajorMatrix",
8888
centroids_db: "colMajorMatrix",
8989
query_vectors: "colMajorMatrix",
90-
index_db: "Vector",
91-
ids_uri: str,
90+
indices: "Vector",
91+
ids: "Vector",
9292
nprobe: int,
9393
k_nn: int,
9494
nth: bool,
@@ -128,12 +128,11 @@ def query_kmeans(
128128

129129
args = tuple(
130130
[
131-
ctx,
132-
parts_uri,
131+
parts_db,
133132
centroids_db,
134133
query_vectors,
135-
index_db,
136-
ids_uri,
134+
indices,
135+
ids,
137136
nprobe,
138137
k_nn,
139138
nth,

0 commit comments

Comments
 (0)