Skip to content

Commit 037b43a

Browse files
Author: Nikos Papailiou (committed)
Commit message: Fix tests
Parent: 76019d8 · Commit: 037b43a

File tree: 3 files changed (+130 additions, −64 deletions)

apis/python/src/tiledb/vector_search/index.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def __init__(
5858
self.base_array_timestamp = self.latest_ingestion_timestamp
5959
self.query_base_array = True
6060
self.update_array_timestamp = (self.base_array_timestamp+1, None)
61-
if timestamp is not None:
61+
if timestamp is None:
62+
self.base_array_timestamp = 0
63+
else:
6264
if isinstance(timestamp, tuple):
6365
if len(timestamp) != 2:
6466
raise ValueError("'timestamp' argument expects either int or tuple(start: int, end: int)")

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 102 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,6 @@ def ingest(
107107

108108
# use index_group_uri for internal clarity
109109
index_group_uri = index_uri
110-
if index_timestamp is None:
111-
index_timestamp = int(time.time() * 1000)
112110

113111
CENTROIDS_ARRAY_NAME = storage_formats[STORAGE_VERSION]["CENTROIDS_ARRAY_NAME"]
114112
INDEX_ARRAY_NAME = storage_formats[STORAGE_VERSION]["INDEX_ARRAY_NAME"]
@@ -1187,21 +1185,37 @@ def ingest_vectors_udf(
11871185
)
11881186
if source_type == "TILEDB_ARRAY":
11891187
logger.debug("Start indexing")
1190-
ivf_index_tdb(
1191-
dtype=vector_type,
1192-
db_uri=source_uri,
1193-
external_ids_uri=external_ids_uri,
1194-
deleted_ids=StdVector_u64(updated_ids),
1195-
centroids_uri=centroids_uri,
1196-
parts_uri=partial_write_array_parts_uri,
1197-
index_array_uri=partial_write_array_index_uri,
1198-
id_uri=partial_write_array_ids_uri,
1199-
start=part,
1200-
end=part_end,
1201-
nthreads=threads,
1202-
timestamp=index_timestamp,
1203-
config=config,
1204-
)
1188+
if index_timestamp is None:
1189+
ivf_index_tdb(
1190+
dtype=vector_type,
1191+
db_uri=source_uri,
1192+
external_ids_uri=external_ids_uri,
1193+
deleted_ids=StdVector_u64(updated_ids),
1194+
centroids_uri=centroids_uri,
1195+
parts_uri=partial_write_array_parts_uri,
1196+
index_array_uri=partial_write_array_index_uri,
1197+
id_uri=partial_write_array_ids_uri,
1198+
start=part,
1199+
end=part_end,
1200+
nthreads=threads,
1201+
config=config,
1202+
)
1203+
else:
1204+
ivf_index_tdb(
1205+
dtype=vector_type,
1206+
db_uri=source_uri,
1207+
external_ids_uri=external_ids_uri,
1208+
deleted_ids=StdVector_u64(updated_ids),
1209+
centroids_uri=centroids_uri,
1210+
parts_uri=partial_write_array_parts_uri,
1211+
index_array_uri=partial_write_array_index_uri,
1212+
id_uri=partial_write_array_ids_uri,
1213+
start=part,
1214+
end=part_end,
1215+
nthreads=threads,
1216+
timestamp=index_timestamp,
1217+
config=config,
1218+
)
12051219
else:
12061220
in_vectors = read_input_vectors(
12071221
source_uri=source_uri,
@@ -1224,21 +1238,37 @@ def ingest_vectors_udf(
12241238
trace_id=trace_id,
12251239
)
12261240
logger.debug("Start indexing")
1227-
ivf_index(
1228-
dtype=vector_type,
1229-
db=array_to_matrix(np.transpose(in_vectors).astype(vector_type)),
1230-
external_ids=StdVector_u64(external_ids),
1231-
deleted_ids=StdVector_u64(updated_ids),
1232-
centroids_uri=centroids_uri,
1233-
parts_uri=partial_write_array_parts_uri,
1234-
index_array_uri=partial_write_array_index_uri,
1235-
id_uri=partial_write_array_ids_uri,
1236-
start=part,
1237-
end=part_end,
1238-
nthreads=threads,
1239-
timestamp=index_timestamp,
1240-
config=config,
1241-
)
1241+
if index_timestamp is None:
1242+
ivf_index(
1243+
dtype=vector_type,
1244+
db=array_to_matrix(np.transpose(in_vectors).astype(vector_type)),
1245+
external_ids=StdVector_u64(external_ids),
1246+
deleted_ids=StdVector_u64(updated_ids),
1247+
centroids_uri=centroids_uri,
1248+
parts_uri=partial_write_array_parts_uri,
1249+
index_array_uri=partial_write_array_index_uri,
1250+
id_uri=partial_write_array_ids_uri,
1251+
start=part,
1252+
end=part_end,
1253+
nthreads=threads,
1254+
config=config,
1255+
)
1256+
else:
1257+
ivf_index(
1258+
dtype=vector_type,
1259+
db=array_to_matrix(np.transpose(in_vectors).astype(vector_type)),
1260+
external_ids=StdVector_u64(external_ids),
1261+
deleted_ids=StdVector_u64(updated_ids),
1262+
centroids_uri=centroids_uri,
1263+
parts_uri=partial_write_array_parts_uri,
1264+
index_array_uri=partial_write_array_index_uri,
1265+
id_uri=partial_write_array_ids_uri,
1266+
start=part,
1267+
end=part_end,
1268+
nthreads=threads,
1269+
timestamp=index_timestamp,
1270+
config=config,
1271+
)
12421272

12431273
def ingest_additions_udf(
12441274
index_group_uri: str,
@@ -1281,21 +1311,37 @@ def ingest_additions_udf(
12811311
trace_id=trace_id,
12821312
)
12831313
logger.debug(f"Ingesting additions {partial_write_array_index_uri}")
1284-
ivf_index(
1285-
dtype=vector_type,
1286-
db=array_to_matrix(np.transpose(additions_vectors).astype(vector_type)),
1287-
external_ids=StdVector_u64(additions_external_ids),
1288-
deleted_ids=StdVector_u64(np.array([], np.uint64)),
1289-
centroids_uri=centroids_uri,
1290-
parts_uri=partial_write_array_parts_uri,
1291-
index_array_uri=partial_write_array_index_uri,
1292-
id_uri=partial_write_array_ids_uri,
1293-
start=write_offset,
1294-
end=0,
1295-
nthreads=threads,
1296-
timestamp=index_timestamp,
1297-
config=config,
1298-
)
1314+
if index_timestamp is None:
1315+
ivf_index(
1316+
dtype=vector_type,
1317+
db=array_to_matrix(np.transpose(additions_vectors).astype(vector_type)),
1318+
external_ids=StdVector_u64(additions_external_ids),
1319+
deleted_ids=StdVector_u64(np.array([], np.uint64)),
1320+
centroids_uri=centroids_uri,
1321+
parts_uri=partial_write_array_parts_uri,
1322+
index_array_uri=partial_write_array_index_uri,
1323+
id_uri=partial_write_array_ids_uri,
1324+
start=write_offset,
1325+
end=0,
1326+
nthreads=threads,
1327+
config=config,
1328+
)
1329+
else:
1330+
ivf_index(
1331+
dtype=vector_type,
1332+
db=array_to_matrix(np.transpose(additions_vectors).astype(vector_type)),
1333+
external_ids=StdVector_u64(additions_external_ids),
1334+
deleted_ids=StdVector_u64(np.array([], np.uint64)),
1335+
centroids_uri=centroids_uri,
1336+
parts_uri=partial_write_array_parts_uri,
1337+
index_array_uri=partial_write_array_index_uri,
1338+
id_uri=partial_write_array_ids_uri,
1339+
start=write_offset,
1340+
end=0,
1341+
nthreads=threads,
1342+
timestamp=index_timestamp,
1343+
config=config,
1344+
)
12991345

13001346
def compute_partition_indexes_udf(
13011347
index_group_uri: str,
@@ -1794,7 +1840,6 @@ def consolidate_and_vacuum(
17941840
raise ValueError(f"New ingestion timestamp: {index_timestamp} can't be smaller that the latest ingestion "
17951841
f"timestamp: {previous_ingestion_timestamp}")
17961842

1797-
ingestion_timestamps.append(index_timestamp)
17981843
group.close()
17991844
group = tiledb.Group(index_group_uri, "w")
18001845

@@ -1818,9 +1863,9 @@ def consolidate_and_vacuum(
18181863
source_uri=source_uri, source_type=source_type
18191864
)
18201865
if size == -1:
1821-
size = in_size
1866+
size = int(in_size)
18221867
if size > in_size:
1823-
size = in_size
1868+
size = int(in_size)
18241869
base_sizes.append(size)
18251870
logger.debug("Input dataset size %d", size)
18261871
logger.debug("Input dataset dimensions %d", dimensions)
@@ -1841,7 +1886,6 @@ def consolidate_and_vacuum(
18411886
group.meta["dtype"] = np.dtype(vector_type).name
18421887
group.meta["partitions"] = partitions
18431888
group.meta["storage_version"] = STORAGE_VERSION
1844-
group.meta["ingestion_timestamps"] = json.dumps(ingestion_timestamps)
18451889
group.meta["base_sizes"] = json.dumps(base_sizes)
18461890

18471891
if external_ids is not None:
@@ -1939,6 +1983,13 @@ def consolidate_and_vacuum(
19391983
d.compute()
19401984
logger.debug("Submitted ingestion graph")
19411985
d.wait()
1986+
1987+
group = tiledb.Group(index_group_uri, "w")
1988+
if index_timestamp is None:
1989+
index_timestamp = int(time.time() * 1000)
1990+
ingestion_timestamps.append(index_timestamp)
1991+
group.meta["ingestion_timestamps"] = json.dumps(ingestion_timestamps)
1992+
group.close()
19421993
consolidate_and_vacuum(index_group_uri=index_group_uri, config=config)
19431994

19441995
if index_type == "FLAT":

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -236,18 +236,31 @@ def dist_qv_udf(
236236
timestamp: int = 0,
237237
):
238238
queries_m = array_to_matrix(np.transpose(query_vectors))
239-
r = dist_qv(
240-
dtype=dtype,
241-
parts_uri=parts_uri,
242-
ids_uri=ids_uri,
243-
query_vectors=queries_m,
244-
active_partitions=active_partitions,
245-
active_queries=active_queries,
246-
indices=indices,
247-
k_nn=k_nn,
248-
ctx=Ctx(config),
249-
timestamp=timestamp,
250-
)
239+
if timestamp == 0:
240+
r = dist_qv(
241+
dtype=dtype,
242+
parts_uri=parts_uri,
243+
ids_uri=ids_uri,
244+
query_vectors=queries_m,
245+
active_partitions=active_partitions,
246+
active_queries=active_queries,
247+
indices=indices,
248+
k_nn=k_nn,
249+
ctx=Ctx(config),
250+
)
251+
else:
252+
r = dist_qv(
253+
dtype=dtype,
254+
parts_uri=parts_uri,
255+
ids_uri=ids_uri,
256+
query_vectors=queries_m,
257+
active_partitions=active_partitions,
258+
active_queries=active_queries,
259+
indices=indices,
260+
k_nn=k_nn,
261+
ctx=Ctx(config),
262+
timestamp=timestamp,
263+
)
251264
results = []
252265
for q in range(len(r)):
253266
tmp_results = []

0 commit comments

Comments (0)