Skip to content

Commit e47bd83

Browse files
authored
Add write_ivecs and write_fvecs utils (#189)
1 parent abe48fe commit e47bd83

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

apis/python/src/tiledb/vector_search/utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def _load_vecs_t(uri, dtype, ctx_or_config=None):
1717
elem_nbytes = int(4 + ndim * dtype.itemsize)
1818
if raw.size % elem_nbytes != 0:
1919
raise ValueError(
20-
f"Mismatched dims to bytes in file {uri}: {raw.size}, elem_nbytes"
20+
f"Mismatched dims to bytes in file {uri}: raw.size: {raw.size}, elem_nbytes: {elem_nbytes}"
2121
)
2222
# take a view on the whole array as
2323
# (ndim, sizeof(t)*ndim), and return the actual elements
@@ -40,3 +40,27 @@ def load_fvecs(uri, ctx_or_config=None):
4040

4141
def load_bvecs(uri, ctx_or_config=None):
4242
return _load_vecs_t(uri, np.uint8, ctx_or_config)
43+
44+
45+
def _write_vecs_t(uri, data, dtype, ctx_or_config=None):
46+
with tiledb.scope_ctx(ctx_or_config) as ctx:
47+
dtype = np.dtype(dtype)
48+
vfs = tiledb.VFS(ctx.config())
49+
ndim = data.shape[1]
50+
51+
buffer = io.BytesIO()
52+
53+
for vector in data:
54+
buffer.write(np.array([ndim], dtype=np.int32).tobytes())
55+
buffer.write(vector.tobytes())
56+
57+
with vfs.open(uri, "wb") as f:
58+
f.write(buffer.getvalue())
59+
60+
61+
def write_ivecs(uri, data, ctx_or_config=None):
62+
_write_vecs_t(uri, data, np.int32, ctx_or_config)
63+
64+
65+
def write_fvecs(uri, data, ctx_or_config=None):
66+
_write_vecs_t(uri, data, np.float32, ctx_or_config)

apis/python/test/test_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import os
2+
import numpy as np
3+
from tiledb.vector_search.utils import load_fvecs, load_ivecs, write_fvecs, write_ivecs
4+
5+
def test_load_and_write_vecs(tmp_path):
6+
fvecs_uri = "test/data/siftsmall/siftsmall_base.fvecs"
7+
ivecs_uri = "test/data/siftsmall/siftsmall_groundtruth.ivecs"
8+
9+
fvecs = load_fvecs(fvecs_uri)
10+
assert fvecs.shape == (10000, 128)
11+
assert not np.any(np.isnan(fvecs))
12+
13+
ivecs = load_ivecs(ivecs_uri)
14+
assert ivecs.shape == (100, 100)
15+
assert not np.any(np.isnan(ivecs))
16+
17+
fvecs_uri = os.path.join(tmp_path, "fvecs")
18+
ivecs_uri = os.path.join(tmp_path, "ivecs")
19+
20+
write_fvecs(fvecs_uri, fvecs[:10])
21+
write_ivecs(ivecs_uri, ivecs[:10])
22+
23+
new_fvecs = load_fvecs(fvecs_uri)
24+
assert new_fvecs.shape == (10, 128)
25+
assert not np.any(np.isnan(fvecs))
26+
27+
new_ivecs = load_ivecs(ivecs_uri)
28+
assert new_ivecs.shape == (10, 100)
29+
assert not np.any(np.isnan(ivecs))
30+

0 commit comments

Comments
 (0)