Skip to content

Commit 334618c

Browse files
author
Nikos Papailiou
committed
Format
1 parent f4eac93 commit 334618c

File tree

5 files changed

+92
-48
lines changed

5 files changed

+92
-48
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
IDS_ARRAY_NAME = "ids.tdb"
1111
PARTS_ARRAY_NAME = "parts.tdb"
1212

13+
1314
def submit_local(d, func, *args, **kwargs):
1415
# Drop kwarg
1516
kwargs.pop("image_name", None)
1617
kwargs.pop("resources", None)
1718
return d.submit_local(func, *args, **kwargs)
1819

20+
1921
class Index:
2022
def query(self, targets: np.ndarray, k=10, nqueries=10, nthreads=8, nprobe=1):
2123
raise NotImplementedError
@@ -205,7 +207,11 @@ def distributed_query(
205207
"""
206208
from tiledb.cloud import dag
207209
from tiledb.cloud.dag import Mode
208-
from tiledb.vector_search.module import array_to_matrix, partition_ivf_index, dist_qv
210+
from tiledb.vector_search.module import (
211+
array_to_matrix,
212+
partition_ivf_index,
213+
dist_qv,
214+
)
209215
import math
210216
import numpy as np
211217
from functools import partial
@@ -218,7 +224,8 @@ def dist_qv_udf(
218224
active_partitions: np.array,
219225
active_queries: np.array,
220226
indices: np.array,
221-
k_nn: int):
227+
k_nn: int,
228+
):
222229
targets_m = array_to_matrix(query_vectors)
223230
r = dist_qv(
224231
dtype=dtype,
@@ -264,10 +271,8 @@ def dist_qv_udf(
264271

265272
targets_m = array_to_matrix(targets)
266273
active_partitions, active_queries = partition_ivf_index(
267-
centroids=self._centroids,
268-
query=targets_m,
269-
nprobe=nprobe,
270-
nthreads=nthreads)
274+
centroids=self._centroids, query=targets_m, nprobe=nprobe, nthreads=nthreads
275+
)
271276
num_parts = len(active_partitions)
272277

273278
parts_per_node = int(math.ceil(num_parts / num_nodes))
@@ -276,19 +281,23 @@ def dist_qv_udf(
276281
part_end = part + parts_per_node
277282
if part_end > num_parts:
278283
part_end = num_parts
279-
nodes.append(submit(
280-
dist_qv_udf,
281-
dtype=self.dtype,
282-
parts_uri=self.parts_db_uri,
283-
ids_uri=self.ids_uri,
284-
query_vectors=targets,
285-
active_partitions=np.array(active_partitions)[part:part_end],
286-
active_queries=np.array(active_queries[part:part_end], dtype=object),
287-
indices=np.array(self._index),
288-
k_nn=k,
289-
resource_class='large',
290-
image_name="3.9-vectorsearch",
291-
))
284+
nodes.append(
285+
submit(
286+
dist_qv_udf,
287+
dtype=self.dtype,
288+
parts_uri=self.parts_db_uri,
289+
ids_uri=self.ids_uri,
290+
query_vectors=targets,
291+
active_partitions=np.array(active_partitions)[part:part_end],
292+
active_queries=np.array(
293+
active_queries[part:part_end], dtype=object
294+
),
295+
indices=np.array(self._index),
296+
k_nn=k,
297+
resource_class="large",
298+
image_name="3.9-vectorsearch",
299+
)
300+
)
292301

293302
d.compute()
294303
d.wait()

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,9 @@ def read_input_vectors(
422422
trace_id: Optional[str] = None,
423423
) -> np.array:
424424
logger = setup(config, verbose)
425-
logger.debug("Reading input vectors start_pos: %i, end_pos: %i", start_pos, end_pos)
425+
logger.debug(
426+
"Reading input vectors start_pos: %i, end_pos: %i", start_pos, end_pos
427+
)
426428
if source_type == "TILEDB_ARRAY":
427429
with tiledb.open(source_uri, mode="r") as src_array:
428430
return np.transpose(
@@ -511,7 +513,9 @@ def copy_centroids(
511513
logger = setup(config, verbose)
512514
group = tiledb.Group(array_uri)
513515
centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
514-
logger.debug("Copying centroids from: %s, to: %s", copy_centroids_uri, centroids_uri)
516+
logger.debug(
517+
"Copying centroids from: %s, to: %s", copy_centroids_uri, centroids_uri
518+
)
515519
src = tiledb.open(copy_centroids_uri, mode="r")
516520
dest = tiledb.open(centroids_uri, mode="w")
517521
src_centroids = src[:, :]
@@ -586,7 +590,9 @@ def init_centroids(
586590
trace_id: Optional[str] = None,
587591
) -> np.array:
588592
logger = setup(config, verbose)
589-
logger.debug("Initialising centroids by reading the first vectors in the source data.")
593+
logger.debug(
594+
"Initialising centroids by reading the first vectors in the source data."
595+
)
590596
with tiledb.scope_ctx(ctx_or_config=config):
591597
return read_input_vectors(
592598
source_uri=source_uri,
@@ -922,7 +928,9 @@ def consolidate_partition_udf(
922928
):
923929
logger = setup(config, verbose)
924930
with tiledb.scope_ctx(ctx_or_config=config):
925-
logger.debug("Consolidating partitions %d-%d", partition_id_start, partition_id_end)
931+
logger.debug(
932+
"Consolidating partitions %d-%d", partition_id_start, partition_id_end
933+
)
926934
group = tiledb.Group(array_uri)
927935
partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
928936
partial_write_array_ids_uri = (
@@ -962,12 +970,16 @@ def consolidate_partition_udf(
962970
index_array = tiledb.open(index_array_uri, mode="r")
963971
ids_array = tiledb.open(ids_array_uri, mode="w")
964972
parts_array = tiledb.open(parts_array_uri, mode="w")
965-
logger.debug("Partitions start: %d end: %d", partition_id_start, partition_id_end)
973+
logger.debug(
974+
"Partitions start: %d end: %d", partition_id_start, partition_id_end
975+
)
966976
for part in range(partition_id_start, partition_id_end, batch):
967977
part_end = part + batch
968978
if part_end > partition_id_end:
969979
part_end = partition_id_end
970-
logger.debug("Consolidating partitions start: %d end: %d", part, part_end)
980+
logger.debug(
981+
"Consolidating partitions start: %d end: %d", part, part_end
982+
)
971983
read_slices = []
972984
for p in range(part, part_end):
973985
for partition_slice in partition_slices[p]:
@@ -985,8 +997,13 @@ def consolidate_partition_udf(
985997
"values"
986998
]
987999

988-
logger.debug("Ids shape %s, expected size: %d expected range:(%d,%d)", ids.shape, end_pos - start_pos,
989-
start_pos, end_pos)
1000+
logger.debug(
1001+
"Ids shape %s, expected size: %d expected range:(%d,%d)",
1002+
ids.shape,
1003+
end_pos - start_pos,
1004+
start_pos,
1005+
end_pos,
1006+
)
9901007
if ids.shape[0] != end_pos - start_pos:
9911008
raise ValueError("Incorrect partition size.")
9921009

@@ -1338,7 +1355,10 @@ def consolidate_and_vacuum(
13381355
logger.debug("input_vectors_per_work_item %d", input_vectors_per_work_item)
13391356
logger.debug("input_vectors_work_items %d", input_vectors_work_items)
13401357
logger.debug("input_vectors_work_tasks %d", input_vectors_work_tasks)
1341-
logger.debug("input_vectors_work_items_per_worker %d", input_vectors_work_items_per_worker)
1358+
logger.debug(
1359+
"input_vectors_work_items_per_worker %d",
1360+
input_vectors_work_items_per_worker,
1361+
)
13421362

13431363
vectors_per_table_partitions = size / partitions
13441364
table_partitions_per_work_item = int(
@@ -1354,10 +1374,15 @@ def consolidate_and_vacuum(
13541374
math.ceil(table_partitions_work_items / MAX_TASKS_PER_STAGE)
13551375
)
13561376
table_partitions_work_tasks = MAX_TASKS_PER_STAGE
1357-
logger.debug("table_partitions_per_work_item %d", table_partitions_per_work_item)
1377+
logger.debug(
1378+
"table_partitions_per_work_item %d", table_partitions_per_work_item
1379+
)
13581380
logger.debug("table_partitions_work_items %d", table_partitions_work_items)
13591381
logger.debug("table_partitions_work_tasks %d", table_partitions_work_tasks)
1360-
logger.debug("table_partitions_work_items_per_worker %d", table_partitions_work_items_per_worker)
1382+
logger.debug(
1383+
"table_partitions_work_items_per_worker %d",
1384+
table_partitions_work_items_per_worker,
1385+
)
13611386

13621387
logger.debug("Creating arrays")
13631388
create_arrays(

apis/python/src/tiledb/vector_search/module.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ def partition_ivf_index(centroids, query, nprobe=1, nthreads=0):
302302
else:
303303
raise TypeError("Unsupported type!")
304304

305+
305306
def dist_qv(
306307
dtype: np.dtype,
307308
parts_uri: str,
@@ -311,7 +312,8 @@ def dist_qv(
311312
active_queries: np.array,
312313
indices: np.array,
313314
k_nn: int,
314-
ctx: "Ctx" = None):
315+
ctx: "Ctx" = None,
316+
):
315317
if ctx is None:
316318
ctx = Ctx({})
317319
args = tuple(
@@ -323,7 +325,7 @@ def dist_qv(
323325
active_queries,
324326
StdVector_u64(indices),
325327
ids_uri,
326-
k_nn
328+
k_nn,
327329
]
328330
)
329331
if dtype == np.float32:
@@ -333,6 +335,7 @@ def dist_qv(
333335
else:
334336
raise TypeError("Unsupported type!")
335337

338+
336339
def validate_top_k(results: np.ndarray, ground_truth: np.ndarray):
337340
if results.dtype == np.uint64:
338341
return validate_top_k_u64(results, ground_truth)

apis/python/test/common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,11 @@ def create_array(path: str, data):
182182
with tiledb.open(path, "w") as A:
183183
A[:] = data
184184

185+
185186
def accuracy(result, gt):
186187
found = 0
187188
total = 0
188189
for i in range(len(result)):
189-
total+=len(result[i])
190-
found+=len(np.intersect1d(result[i], gt[i]))
191-
return found/total
190+
total += len(result[i])
191+
found += len(np.intersect1d(result[i], gt[i]))
192+
return found / total

apis/python/test/test_ingestion.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
MINIMUM_ACCURACY = 0.9
88

9+
910
def test_flat_ingestion_u8(tmp_path):
1011
dataset_dir = os.path.join(tmp_path, "dataset")
1112
array_uri = os.path.join(tmp_path, "array")
@@ -47,6 +48,7 @@ def test_flat_ingestion_f32(tmp_path):
4748
result = np.transpose(index.query(np.transpose(query_vectors), k=k))
4849
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
4950

51+
5052
def test_ivf_flat_ingestion_u8(tmp_path):
5153
dataset_dir = os.path.join(tmp_path, "dataset")
5254
array_uri = os.path.join(tmp_path, "array")
@@ -70,9 +72,7 @@ def test_ivf_flat_ingestion_u8(tmp_path):
7072
partitions=partitions,
7173
input_vectors_per_work_item=int(size / 10),
7274
)
73-
result = np.transpose(
74-
index.query(np.transpose(query_vectors), k=k, nprobe=nprobe)
75-
)
75+
result = np.transpose(index.query(np.transpose(query_vectors), k=k, nprobe=nprobe))
7676
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
7777

7878
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype, memory_budget=int(size / 10))
@@ -90,11 +90,18 @@ def test_ivf_flat_ingestion_u8(tmp_path):
9090
)
9191
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
9292

93-
result = index_ram.distributed_query(np.transpose(query_vectors.astype(np.uint8)), k=k, nprobe=nprobe, mode=Mode.LOCAL)
93+
result = index_ram.distributed_query(
94+
np.transpose(query_vectors.astype(np.uint8)),
95+
k=k,
96+
nprobe=nprobe,
97+
mode=Mode.LOCAL,
98+
)
9499
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
95100

101+
96102
def test_ivf_flat_ingestion_f32(tmp_path):
97103
import time
104+
98105
dataset_dir = os.path.join(tmp_path, "dataset")
99106
array_uri = os.path.join(tmp_path, "array")
100107
k = 10
@@ -120,9 +127,7 @@ def test_ivf_flat_ingestion_f32(tmp_path):
120127
input_vectors_per_work_item=int(size / 10),
121128
)
122129

123-
result = np.transpose(
124-
index.query(np.transpose(query_vectors), k=k, nprobe=nprobe)
125-
)
130+
result = np.transpose(index.query(np.transpose(query_vectors), k=k, nprobe=nprobe))
126131
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
127132

128133
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype, memory_budget=int(size / 10))
@@ -140,11 +145,12 @@ def test_ivf_flat_ingestion_f32(tmp_path):
140145
)
141146
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
142147

143-
result = index_ram.distributed_query(np.transpose(query_vectors), k=k, nprobe=nprobe, mode=Mode.LOCAL)
148+
result = index_ram.distributed_query(
149+
np.transpose(query_vectors), k=k, nprobe=nprobe, mode=Mode.LOCAL
150+
)
144151
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
145152

146153

147-
148154
def test_ivf_flat_ingestion_fvec(tmp_path):
149155
source_uri = "test/data/siftsmall/siftsmall_base.fvecs"
150156
queries_uri = "test/data/siftsmall/siftsmall_query.fvecs"
@@ -170,9 +176,7 @@ def test_ivf_flat_ingestion_fvec(tmp_path):
170176
source_type=source_type,
171177
partitions=partitions,
172178
)
173-
result = np.transpose(
174-
index.query(np.transpose(query_vectors), k=k, nprobe=nprobe)
175-
)
179+
result = np.transpose(index.query(np.transpose(query_vectors), k=k, nprobe=nprobe))
176180
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
177181

178182
index_ram = IVFFlatIndex(uri=array_uri, dtype=dtype)
@@ -190,5 +194,7 @@ def test_ivf_flat_ingestion_fvec(tmp_path):
190194
)
191195
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
192196

193-
result = index_ram.distributed_query(np.transpose(query_vectors), k=k, nprobe=nprobe, mode=Mode.LOCAL)
197+
result = index_ram.distributed_query(
198+
np.transpose(query_vectors), k=k, nprobe=nprobe, mode=Mode.LOCAL
199+
)
194200
assert accuracy(result, gt_i) > MINIMUM_ACCURACY

0 commit comments

Comments
 (0)