
Commit 2989635

Author: Nikos Papailiou (committed)
Commit message: Fix format and tests
1 parent 35db4a3, commit 2989635

File tree: 5 files changed (+79, -51 lines changed)


apis/python/src/tiledb/vector_search/index.py

Lines changed: 15 additions & 8 deletions

@@ -6,6 +6,7 @@
 import numpy as np
 from tiledb.vector_search.module import *
 from tiledb.cloud.dag import Mode
+from typing import Any, Mapping

 CENTROIDS_ARRAY_NAME = "centroids.tdb"
 INDEX_ARRAY_NAME = "index.tdb"
@@ -39,13 +40,17 @@ class FlatIndex(Index):
         Optional name of partitions
     """

-    def __init__(self, uri: str, dtype: np.dtype, parts_name: str = "parts.tdb", ctx: "Ctx" = None):
+    def __init__(
+        self,
+        uri: str,
+        dtype: np.dtype,
+        parts_name: str = "parts.tdb",
+        config: Optional[Mapping[str, Any]] = None,
+    ):
         self.uri = uri
         self.dtype = dtype
         self._index = None
-        self.ctx = ctx
-        if ctx is None:
-            self.ctx = Ctx({})
+        self.ctx = Ctx(config)

         self._db = load_as_matrix(os.path.join(uri, parts_name), ctx=self.ctx)

@@ -106,7 +111,11 @@ class IVFFlatIndex(Index):
     """

     def __init__(
-        self, uri, dtype: np.dtype, memory_budget: int = -1, ctx: "Ctx" = None
+        self,
+        uri,
+        dtype: np.dtype,
+        memory_budget: int = -1,
+        config: Optional[Mapping[str, Any]] = None,
     ):
         group = tiledb.Group(uri)
         self.parts_db_uri = group[PARTS_ARRAY_NAME].uri
@@ -115,9 +124,7 @@ def __init__(
         self.ids_uri = group[IDS_ARRAY_NAME].uri
         self.dtype = dtype
         self.memory_budget = memory_budget
-        self.ctx = ctx
-        if ctx is None:
-            self.ctx = Ctx({})
+        self.ctx = Ctx(config)

         # TODO pass in a context
         if self.memory_budget == -1:

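The constructor change above replaces the ctx argument with a config mapping, and each index now builds its own context internally via Ctx(config). A minimal usage sketch under that assumption; the URIs and config values below are placeholders and are not part of this commit:

import numpy as np

from tiledb.vector_search.index import FlatIndex, IVFFlatIndex

# Placeholder TileDB configuration parameters for illustration only.
config = {"vfs.s3.region": "us-east-1"}

# Callers now pass the raw mapping instead of a pre-built Ctx object.
flat_index = FlatIndex(uri="/tmp/flat_array", dtype=np.float32, config=config)
ivf_index = IVFFlatIndex(
    uri="/tmp/ivf_array",
    dtype=np.float32,
    memory_budget=1000000,
    config=config,
)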
apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 51 additions & 24 deletions

@@ -2,8 +2,7 @@
 from functools import partial

 from tiledb.cloud.dag import Mode
-from tiledb.vector_search.index import FlatIndex
-from tiledb.vector_search.index import IVFFlatIndex
+from tiledb.vector_search.index import FlatIndex, IVFFlatIndex, Index


 def ingest(
@@ -23,7 +22,7 @@ def ingest(
     verbose: bool = False,
     trace_id: Optional[str] = None,
     mode: Mode = Mode.LOCAL,
-) -> FlatIndex:
+) -> Index:
     """
     Ingest vectors into TileDB.

@@ -271,9 +270,7 @@ def create_arrays(
         )
         logger.debug(centroids_schema)
         tiledb.Array.create(centroids_uri, centroids_schema)
-        group.add(
-            centroids_uri, name=CENTROIDS_ARRAY_NAME
-        )
+        group.add(centroids_uri, name=CENTROIDS_ARRAY_NAME)

     if not tiledb.array_exists(index_uri):
         logger.debug("Creating index array")
@@ -354,7 +351,9 @@ def create_arrays(
     except tiledb.TileDBError as err:
         message = str(err)
         if "already exists" in message:
-            logger.debug(f"Group '{partial_write_array_dir_uri}' already exists")
+            logger.debug(
+                f"Group '{partial_write_array_dir_uri}' already exists"
+            )
         raise err
     partial_write_array_group = tiledb.Group(partial_write_array_dir_uri, "w")
     group.add(partial_write_array_dir_uri, name=PARTIAL_WRITE_ARRAY_DIR)
@@ -364,10 +363,16 @@ def create_arrays(
     except tiledb.TileDBError as err:
         message = str(err)
         if "already exists" in message:
-            logger.debug(f"Group '{partial_write_array_index_uri}' already exists")
+            logger.debug(
+                f"Group '{partial_write_array_index_uri}' already exists"
+            )
         raise err
-    partial_write_array_group.add(partial_write_array_index_uri, name=INDEX_ARRAY_NAME)
-    partial_write_array_index_group = tiledb.Group(partial_write_array_index_uri, "w")
+    partial_write_array_group.add(
+        partial_write_array_index_uri, name=INDEX_ARRAY_NAME
+    )
+    partial_write_array_index_group = tiledb.Group(
+        partial_write_array_index_uri, "w"
+    )

     if not tiledb.array_exists(partial_write_array_ids_uri):
         logger.debug("Creating temp ids array")
@@ -389,7 +394,9 @@ def create_arrays(
         )
         logger.debug(ids_schema)
         tiledb.Array.create(partial_write_array_ids_uri, ids_schema)
-        partial_write_array_group.add(partial_write_array_ids_uri, name=IDS_ARRAY_NAME)
+        partial_write_array_group.add(
+            partial_write_array_ids_uri, name=IDS_ARRAY_NAME
+        )

     if not tiledb.array_exists(partial_write_array_parts_uri):
         logger.debug("Creating temp parts array")
@@ -420,10 +427,12 @@ def create_arrays(
         logger.debug(parts_schema)
         logger.debug(partial_write_array_parts_uri)
         tiledb.Array.create(partial_write_array_parts_uri, parts_schema)
-        partial_write_array_group.add(partial_write_array_parts_uri, name=PARTS_ARRAY_NAME)
+        partial_write_array_group.add(
+            partial_write_array_parts_uri, name=PARTS_ARRAY_NAME
+        )

     for part in range(input_vectors_work_tasks):
-        part_index_uri = partial_write_array_index_uri+"/"+str(part)
+        part_index_uri = partial_write_array_index_uri + "/" + str(part)
         if not tiledb.array_exists(part_index_uri):
             logger.debug(f"Creating part array {part_index_uri}")
             index_array_rows_dim = tiledb.Dim(
@@ -867,8 +876,12 @@ def ingest_vectors_udf(
     partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
     partial_write_array_ids_uri = partial_write_array_group[IDS_ARRAY_NAME].uri
     partial_write_array_parts_uri = partial_write_array_group[PARTS_ARRAY_NAME].uri
-    partial_write_array_index_dir_uri = partial_write_array_group[INDEX_ARRAY_NAME].uri
-    partial_write_array_index_group = tiledb.Group(partial_write_array_index_dir_uri)
+    partial_write_array_index_dir_uri = partial_write_array_group[
+        INDEX_ARRAY_NAME
+    ].uri
+    partial_write_array_index_group = tiledb.Group(
+        partial_write_array_index_dir_uri
+    )

     for part in range(start, end, batch):
         part_end = part + batch
@@ -877,7 +890,9 @@ def ingest_vectors_udf(

         part_name = str(part) + "-" + str(part_end)

-        partial_write_array_index_uri = partial_write_array_index_group[str(int(start / batch))].uri
+        partial_write_array_index_uri = partial_write_array_index_group[
+            str(int(start / batch))
+        ].uri
         logger.debug("Input vectors start_pos: %d, end_pos: %d", part, part_end)
         if source_type == "TILEDB_ARRAY":
             logger.debug("Start indexing")
@@ -932,8 +947,12 @@ def compute_partition_indexes_udf(
     index_array_uri = group[INDEX_ARRAY_NAME].uri
     partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
     partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
-    partial_write_array_index_dir_uri = partial_write_array_group[INDEX_ARRAY_NAME].uri
-    partial_write_array_index_group = tiledb.Group(partial_write_array_index_dir_uri)
+    partial_write_array_index_dir_uri = partial_write_array_group[
+        INDEX_ARRAY_NAME
+    ].uri
+    partial_write_array_index_group = tiledb.Group(
+        partial_write_array_index_dir_uri
+    )
     partition_sizes = np.zeros(partitions)
     indexes = np.zeros(partitions + 1).astype(np.uint64)
     for part in partial_write_array_index_group:
@@ -978,9 +997,15 @@ def consolidate_partition_udf(
     partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
     partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
     partial_write_array_ids_uri = partial_write_array_group[IDS_ARRAY_NAME].uri
-    partial_write_array_parts_uri = partial_write_array_group[PARTS_ARRAY_NAME].uri
-    partial_write_array_index_dir_uri = partial_write_array_group[INDEX_ARRAY_NAME].uri
-    partial_write_array_index_group = tiledb.Group(partial_write_array_index_dir_uri)
+    partial_write_array_parts_uri = partial_write_array_group[
+        PARTS_ARRAY_NAME
+    ].uri
+    partial_write_array_index_dir_uri = partial_write_array_group[
+        INDEX_ARRAY_NAME
+    ].uri
+    partial_write_array_index_group = tiledb.Group(
+        partial_write_array_index_dir_uri
+    )
     index_array_uri = group[INDEX_ARRAY_NAME].uri
     ids_array_uri = group[IDS_ARRAY_NAME].uri
     parts_array_uri = group[PARTS_ARRAY_NAME].uri
@@ -1342,7 +1367,7 @@ def consolidate_and_vacuum(
         if vfs.is_dir(partial_write_array_dir_uri):
             vfs.remove_dir(partial_write_array_dir_uri)

-    with tiledb.scope_ctx(ctx_or_config=config) as ctx:
+    with tiledb.scope_ctx(ctx_or_config=config):
         logger = setup(config, verbose)
         logger.debug("Ingesting Vectors into %r", array_uri)
         try:
@@ -1462,6 +1487,8 @@ def consolidate_and_vacuum(
         consolidate_and_vacuum(array_uri=array_uri, config=config)

         if index_type == "FLAT":
-            return FlatIndex(uri=array_uri, dtype=vector_type, ctx=ctx)
+            return FlatIndex(uri=array_uri, dtype=vector_type, config=config)
         elif index_type == "IVF_FLAT":
-            return IVFFlatIndex(uri=array_uri, dtype=vector_type, memory_budget=1000000, ctx=ctx)
+            return IVFFlatIndex(
+                uri=array_uri, dtype=vector_type, memory_budget=1000000, config=config
+            )

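With the import and return-type changes above, ingest() is now annotated to return the Index base class and passes its config mapping to whichever concrete index it builds. A rough usage sketch; the source_uri/source_type keyword names, file paths, and vector dimensions are assumptions for illustration and may not match the exact ingest() signature:

import numpy as np

from tiledb.vector_search.ingestion import ingest

config = {"vfs.s3.region": "us-east-1"}  # placeholder TileDB config

# Returns a FlatIndex or IVFFlatIndex depending on index_type; both satisfy Index.
index = ingest(
    index_type="IVF_FLAT",
    array_uri="/tmp/array",
    source_uri="/tmp/dataset/data.f32bin",
    source_type="F32BIN",
    config=config,
)

# Query the returned IVF index; nprobe mirrors how the tests in this repo query it.
query_vectors = np.random.rand(10, 128).astype(np.float32)
result = index.query(query_vectors, k=10, nprobe=10)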
apis/python/src/tiledb/vector_search/module.py

Lines changed: 3 additions & 3 deletions

@@ -18,8 +18,8 @@ def load_as_matrix(path: str, nqueries: int = 0, ctx: "Ctx" = None):
         Array path
     nqueries: int
         Number of queries
-    config: Dict
-        TileDB configuration parameters
+    ctx: Ctx
+        TileDB context
     """
     if ctx is None:
         ctx = Ctx({})
@@ -44,7 +44,7 @@ def load_as_matrix(path: str, nqueries: int = 0, ctx: "Ctx" = None):
     return m


-def load_as_array(path, return_matrix: bool = False, ctx: "Ctx" = None):
+def load_as_array(path, return_matrix: bool = False, ctx: "Ctx" = None):
     """
     Load array as array class

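The docstring fix above clarifies that load_as_matrix takes a TileDB context (ctx), not a raw config dict, so callers that only have configuration parameters wrap them in a Ctx first. A small sketch under that reading, assuming Ctx is importable from tiledb.vector_search.module as the wildcard import in index.py suggests; the path and config key are placeholders:

import os

from tiledb.vector_search.module import Ctx, load_as_array, load_as_matrix

uri = "/tmp/array"  # placeholder index location

# Build a Ctx from config parameters and pass it explicitly.
ctx = Ctx({"sm.memory_budget": "2147483648"})
matrix = load_as_matrix(os.path.join(uri, "parts.tdb"), ctx=ctx)

# load_as_array takes the same optional ctx; load_as_matrix falls back to Ctx({}) when it is omitted.
arr = load_as_array(os.path.join(uri, "parts.tdb"), ctx=ctx)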
apis/python/test/common.py

Lines changed: 4 additions & 4 deletions

@@ -96,7 +96,7 @@ def create_random_dataset_f32(nb, d, nq, k, path):
     from sklearn.datasets import make_blobs
     from sklearn.neighbors import NearestNeighbors

-    #print(f"Preparing datasets with {nb} random points and {nq} queries.")
+    # print(f"Preparing datasets with {nb} random points and {nq} queries.")
     os.mkdir(path)
     X, _ = make_blobs(n_samples=nb + nq, n_features=d, centers=nq, random_state=1)

@@ -111,7 +111,7 @@ def create_random_dataset_f32(nb, d, nq, k, path):
         np.array([nq, d], dtype="uint32").tofile(f)
         queries.astype("float32").tofile(f)

-    #print("Computing groundtruth")
+    # print("Computing groundtruth")

     nbrs = NearestNeighbors(n_neighbors=k, metric="euclidean", algorithm="brute").fit(
         data
@@ -128,7 +128,7 @@ def create_random_dataset_u8(nb, d, nq, k, path):
     from sklearn.datasets import make_blobs
     from sklearn.neighbors import NearestNeighbors

-    #print(f"Preparing datasets with {nb} random points and {nq} queries.")
+    # print(f"Preparing datasets with {nb} random points and {nq} queries.")
     os.mkdir(path)
     X, _ = make_blobs(n_samples=nb + nq, n_features=d, centers=nq, random_state=1)

@@ -145,7 +145,7 @@ def create_random_dataset_u8(nb, d, nq, k, path):
         np.array([nq, d], dtype="uint32").tofile(f)
         queries.tofile(f)

-    #print("Computing groundtruth")
+    # print("Computing groundtruth")

     nbrs = NearestNeighbors(n_neighbors=k, metric="euclidean", algorithm="brute").fit(
         data

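The helpers touched above only had their comment formatting changed; for context, they generate a random dataset plus brute-force ground truth for the tests. A minimal sketch of how such a helper is typically driven, assuming the test-style import from common; the sizes and path are placeholders:

import os

from common import create_random_dataset_f32

dataset_dir = os.path.join("/tmp/vector-search-test", "dataset")

# Writes nb base vectors, nq query vectors, and k-NN ground truth under dataset_dir.
create_random_dataset_f32(nb=10000, d=100, nq=100, k=10, path=dataset_dir)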
apis/python/test/test_ingestion.py

Lines changed: 6 additions & 12 deletions

@@ -8,9 +8,8 @@

 MINIMUM_ACCURACY = 0.9

-@pytest.mark.parametrize(
-    "query_type", ["heap", "nth"]
-)
+
+@pytest.mark.parametrize("query_type", ["heap", "nth"])
 def test_flat_ingestion_u8(tmp_path, query_type):
     dataset_dir = os.path.join(tmp_path, "dataset")
     array_uri = os.path.join(tmp_path, "array")
@@ -31,9 +30,8 @@ def test_flat_ingestion_u8(tmp_path, query_type):
     result = index.query(query_vectors, k=k, query_type=query_type)
     assert accuracy(result, gt_i) > MINIMUM_ACCURACY

-@pytest.mark.parametrize(
-    "query_type", ["heap", "nth"]
-)
+
+@pytest.mark.parametrize("query_type", ["heap", "nth"])
 def test_flat_ingestion_f32(tmp_path, query_type):
     dataset_dir = os.path.join(tmp_path, "dataset")
     array_uri = os.path.join(tmp_path, "array")
@@ -143,9 +141,7 @@ def test_ivf_flat_ingestion_f32(tmp_path):
     )
     assert accuracy(result, gt_i) > MINIMUM_ACCURACY

-    result = index_ram.query(
-        query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL
-    )
+    result = index_ram.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL)
     assert accuracy(result, gt_i) > MINIMUM_ACCURACY


@@ -189,7 +185,5 @@ def test_ivf_flat_ingestion_fvec(tmp_path):
     )
     assert accuracy(result, gt_i) > MINIMUM_ACCURACY

-    result = index_ram.query(
-        query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL
-    )
+    result = index_ram.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL)
     assert accuracy(result, gt_i) > MINIMUM_ACCURACY
