Skip to content

Commit 8223171

Browse files
authored
Use std::vector as opaque type to avoid copy (#65)
* Use array_to_matrix to clean up cruft * Use std::vector as opaque type to avoid copy Cleanup * Don't use span, unnecessary * Add py buffer support
1 parent 7bbb42f commit 8223171

File tree

4 files changed

+57
-18
lines changed

4 files changed

+57
-18
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,11 @@ def query(
5757
"""
5858
# TODO:
5959
# - typecheck targets
60-
# - don't copy the array
6160
# - add all the options and query strategies
6261

6362
assert targets.dtype == np.float32
6463

65-
# TODO: make Matrix constructor from py::array. This is ugly (and copies).
66-
# Create a Matrix from the input targets
67-
targets_m = ColMajorMatrix_f32(*targets.shape)
68-
targets_m_a = np.array(targets_m, copy=False)
69-
targets_m_a[:] = targets
64+
targets_m = array_to_matrix(targets)
7065

7166
r = query_vq(self._db, targets_m, k, nqueries, nthreads)
7267
return np.array(r)
@@ -116,10 +111,7 @@ def query(self, targets: np.ndarray, k=10, nqueries=10, nthreads=8, nprobe=1):
116111
"""
117112
assert targets.dtype == np.float32
118113

119-
# TODO: use Matrix constructor from py::array
120-
targets_m = ColMajorMatrix_f32(*targets.shape)
121-
targets_m_a = np.array(targets_m, copy=False)
122-
targets_m_a[:] = targets
114+
targets_m = array_to_matrix(targets)
123115

124116
r = query_kmeans(
125117
self._db.dtype,

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ using Ctx = tiledb::Context;
1414
bool global_debug = true;
1515
double global_time_of_interest;
1616

17+
PYBIND11_MAKE_OPAQUE(std::vector<uint32_t>);
18+
PYBIND11_MAKE_OPAQUE(std::vector<uint64_t>);
19+
1720
namespace {
1821

1922

@@ -91,18 +94,19 @@ static void declare_pyarray_to_matrix(py::module& m, const std::string& suffix)
9194
});
9295
}
9396

94-
template <typename T>
97+
template <typename T, typename Id_Type = uint64_t>
9598
static void declare_kmeans_query(py::module& m, const std::string& suffix) {
9699
m.def(("kmeans_query_" + suffix).c_str(),
97100
[](const ColMajorMatrix<T>& parts,
98101
const ColMajorMatrix<float>& centroids,
99102
const ColMajorMatrix<float>& query_vectors,
100-
std::vector<uint64_t>& indices,
101-
std::vector<uint64_t>& ids,
103+
std::vector<Id_Type> indices,
104+
std::vector<Id_Type> ids,
102105
size_t nprobe,
103106
size_t k_nn,
104107
bool nth,
105108
size_t nthreads) -> ColMajorMatrix<size_t> { // TODO change return type
109+
106110
auto r = detail::ivf::qv_query_heap_infinite_ram(
107111
parts,
108112
centroids,
@@ -132,8 +136,29 @@ static void declareColMajorMatrixSubclass(py::module& mod,
132136
cls.def(py::init<const Ctx&, std::string, size_t>(), py::keep_alive<1,2>());
133137
}
134138

139+
template <typename T>
140+
void declareStdVector(py::module& m) {
141+
142+
auto name = std::string("IntVector") + typeid(T).name();
143+
py::class_<std::vector<T>>(m, name.c_str(), py::buffer_protocol())
144+
.def(py::init<>())
145+
.def("clear", &std::vector<T>::clear)
146+
.def("pop_back", &std::vector<T>::pop_back)
147+
.def("__len__", [](const std::vector<T> &v) { return v.size(); })
148+
.def_buffer([](std::vector<T> &v) -> py::buffer_info {
149+
return py::buffer_info(
150+
v.data(), /* Pointer to buffer */
151+
sizeof(T), /* Size of one scalar */
152+
py::format_descriptor<T>::format(), /* Python struct-style format descriptor */
153+
1, /* Number of dimensions */
154+
{ v.size() }, /* Buffer dimensions */
155+
{ sizeof(T) });
156+
});
135157
}
136158

159+
} // anonymous namespace
160+
161+
137162
PYBIND11_MODULE(_tiledbvspy, m) {
138163

139164
py::class_<tiledb::Context> (m, "Ctx", py::module_local())
@@ -150,14 +175,19 @@ PYBIND11_MODULE(_tiledbvspy, m) {
150175

151176
/* === Vector === */
152177

153-
declareVector<uint32_t>(m, "_u32");
154-
declareVector<uint64_t>(m, "_u64");
155-
declareVector<float>(m, "_f32");
156-
declareVector<double>(m, "_f64");
178+
// Must have matching PYBIND11_MAKE_OPAQUE declaration at top of file
179+
declareStdVector<uint32_t>(m);
180+
declareStdVector<uint64_t>(m);
157181

158182
m.def("read_vector_u32", &read_vector<uint32_t>, "Read a vector from TileDB");
159183
m.def("read_vector_u64", &read_vector<uint64_t>, "Read a vector from TileDB");
160184

185+
m.def("_create_vector_u64", []() {
186+
auto v = std::vector<uint64_t>(10);
187+
// fill vector with the values 0..9 using std::iota
188+
std::iota(v.begin(), v.begin() + 10, 0);
189+
return v;
190+
});
161191

162192
/* === Matrix === */
163193

apis/python/test/test_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def test_load_matrix(tmpdir):
2424
assert np.array_equal(m, data)
2525
assert np.array_equal(orig_matrix[0, 0], data[0, 0])
2626

27+
def test_vector(tmpdir):
28+
v = vspy._create_vector_u64()
29+
assert np.array_equal(np.array(v), np.arange(10))
2730

2831
@pytest.mark.skipif(
2932
not os.path.exists(os.path.expanduser("~/work/proj/vector-search/datasets/sift-andrew/")),
@@ -58,4 +61,4 @@ def test_flat_query():
5861
assert np.array_equal(np.sort(ra[:k], axis=0), np.sort(g[:k, :nqueries], axis=0))
5962

6063
g_m = vs.load_as_matrix(g_uri)
61-
assert vspy.validate_top_k_u64(r, g_m)
64+
assert vspy.validate_top_k_u64(r, g_m)

src/include/detail/linalg/matrix.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,20 @@ std::string matrix_info(const std::vector<T>& A, const std::string& msg = "") {
292292
return str;
293293
}
294294

295+
/**
296+
* Build and return a string describing a std::span (its size) -- overload.
297+
* @param A
298+
*/
299+
template <class T>
300+
std::string matrix_info(const std::span<T>& A, const std::string& msg = "") {
301+
std::string str = "# " + msg;
302+
if (!msg.empty()) {
303+
str += ": ";
304+
}
305+
str += "Shape: (" + std::to_string(A.size()) + " )";
306+
return str;
307+
}
308+
295309
template <class Matrix>
296310
void debug_matrix(const Matrix& A, const std::string& msg = "") {
297311
if (global_debug) {

0 commit comments

Comments
 (0)