Skip to content

Commit bd36040

Browse files
committed
Test and fix array_to_matrix
1 parent 9b73d50 commit bd36040

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,15 @@ static void declareColMajorMatrix(py::module& mod, std::string const& suffix) {
7777
template <typename T>
7878
static void declare_pyarray_to_matrix(py::module& m, const std::string& suffix) {
7979
m.def(("pyarray_copyto_matrix" + suffix).c_str(),
80-
[](py::array_t<T> arr) -> ColMajorMatrix<T> {
80+
[](py::array_t<T, py::array::f_style> arr) -> ColMajorMatrix<T> {
8181
py::buffer_info info = arr.request();
8282
if (info.ndim != 2)
8383
throw std::runtime_error("Number of dimensions must be two");
8484
if (info.format != py::format_descriptor<T>::format())
8585
throw std::runtime_error("Mismatched buffer format!");
8686

8787
auto data = std::unique_ptr<T[]>{new T[info.shape[0] * info.shape[1]]};
88+
std::memcpy(data.get(), info.ptr, info.shape[0] * info.shape[1] * sizeof(T));
8889
auto r = ColMajorMatrix<T>(std::move(data), info.shape[0], info.shape[1]);
8990
return r;
9091
});

apis/python/test/test_module.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
from common import *
33

4+
import tiledb.vector_search as vs
45
from tiledb.vector_search import _tiledbvspy as vspy
56

67

@@ -19,10 +20,20 @@ def test_tdbMatrix(tmpdir):
1920
m_array2 = np.array(m, copy=False) # mutable view
2021
v = np.random.rand(1).astype(np.float32)
2122
m_array2[1, 2] = v
23+
2224
data[1, 2] = v
25+
2326
assert np.array_equal(m_array2, data)
2427
assert m[1, 2] == v
2528

29+
def test_array_to_matrix(tmpdir):
30+
p = str(tmpdir.mkdir("test").join("test.tdb"))
31+
32+
data = np.random.rand(12).astype(np.float32).reshape(3, 4)
33+
34+
mat = vs.array_to_matrix(data)
35+
mat_view = np.array(mat, copy=True) # mutable view
36+
assert np.array_equal(mat_view, data)
2637

2738
def test_context(tmpdir):
2839
str(tmpdir.mkdir("test").join("test.tdb"))

0 commit comments

Comments
 (0)