Skip to content

Commit 18e4818

Browse files
author
Nikos Papailiou
committed
Add utils to clear update history
1 parent 45cf7fe commit 18e4818

File tree

2 files changed

+164
-3
lines changed

2 files changed

+164
-3
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,15 @@ def delete_batch(self, external_ids: np.array, timestamp: int = None):
241241

242242
def consolidate_update_fragments(self):
243243
fragments_info = tiledb.array_fragments(self.updates_array_uri)
244-
if len(fragments_info) > 10:
245-
tiledb.consolidate(self.updates_array_uri)
246-
tiledb.vacuum(self.updates_array_uri)
244+
count_fragments = 0
245+
for timestamp_range in fragments_info.timestamp_range:
246+
if timestamp_range[1] > self.latest_ingestion_timestamp:
247+
count_fragments += 1
248+
if count_fragments > 10:
249+
conf = tiledb.Config(self.config)
250+
conf["sm.consolidation.timestamp_start"] = self.latest_ingestion_timestamp
251+
tiledb.consolidate(self.updates_array_uri, config=conf)
252+
tiledb.vacuum(self.updates_array_uri, config=conf)
247253

248254
def get_updates_uri(self):
249255
return self.updates_array_uri
@@ -290,6 +296,13 @@ def consolidate_updates(self):
290296
for fragment_info in fragments_info:
291297
if fragment_info.timestamp_range[1] > max_timestamp:
292298
max_timestamp = fragment_info.timestamp_range[1]
299+
max_timestamp += 1
300+
conf = tiledb.Config(self.config)
301+
conf["sm.consolidation.timestamp_start"] = self.latest_ingestion_timestamp
302+
conf["sm.consolidation.timestamp_end"] = max_timestamp
303+
tiledb.consolidate(self.updates_array_uri, config=conf)
304+
tiledb.vacuum(self.updates_array_uri, config=conf)
305+
293306
new_index = ingest(
294307
index_type=self.index_type,
295308
index_uri=self.uri,
@@ -314,3 +327,64 @@ def delete_index(uri, config):
314327
else:
315328
raise err
316329
group.delete()
330+
331+
@staticmethod
def clear_history(
    uri: str,
    timestamp: int,
    config: Optional[Mapping[str, Any]] = None,
):
    """Clear the index history up to and including ``timestamp``.

    The group metadata is rewritten so that only ingestion timestamps
    strictly greater than ``timestamp`` (and their matching base sizes) are
    kept; when none remain, the metadata is reset to ``[0]`` / ``[1]``.
    Afterwards ``delete_fragments(0, timestamp)`` is called on the updates
    array (if present in the group) and on the arrays of the index type
    recorded in the group metadata (``FLAT`` or ``IVF_FLAT``).

    Args:
        uri: URI of the index group.
        timestamp: Upper bound (inclusive) of the history to clear.
        config: Optional TileDB configuration. It is used for every group
            and array opened by this function.

    Raises:
        ValueError: If the storage version of the index does not support
            time travel.
    """
    # Build the context once so that the same configuration (credentials,
    # VFS settings, ...) applies to the group AND to the arrays below.
    ctx = tiledb.Ctx(config)

    # Phase 1: read the current metadata. try/finally guarantees the group
    # handle is released even when validation fails.
    group = tiledb.Group(uri, "r", ctx=ctx)
    try:
        storage_version = group.meta.get("storage_version", "0.1")
        if not storage_formats[storage_version]["SUPPORT_TIMETRAVEL"]:
            raise ValueError(
                f"Time traveling is not supported for index storage_version={storage_version}"
            )
        ingestion_timestamps = [
            int(x) for x in json.loads(group.meta.get("ingestion_timestamps", "[]"))
        ]
        base_sizes = [int(x) for x in json.loads(group.meta.get("base_sizes", "[]"))]
        index_type = group.meta.get("index_type", "")
    finally:
        group.close()

    # Keep only the ingestions that happened strictly after `timestamp`.
    # base_sizes is indexed by position so inconsistent metadata still fails
    # loudly (IndexError) instead of being silently truncated.
    kept = [
        (ingestion_timestamp, base_sizes[position])
        for position, ingestion_timestamp in enumerate(ingestion_timestamps)
        if ingestion_timestamp > timestamp
    ]
    if kept:
        new_ingestion_timestamps = [ingestion_timestamp for ingestion_timestamp, _ in kept]
        new_base_sizes = [base_size for _, base_size in kept]
    else:
        # Nothing survives: fall back to the placeholder metadata values.
        new_ingestion_timestamps = [0]
        new_base_sizes = [1]

    # Phase 2: persist the trimmed metadata.
    group = tiledb.Group(uri, "w", ctx=ctx)
    try:
        group.meta["ingestion_timestamps"] = json.dumps(new_ingestion_timestamps)
        group.meta["base_sizes"] = json.dumps(new_base_sizes)
    finally:
        group.close()

    # Phase 3: drop the fragments written at or before `timestamp`.
    group = tiledb.Group(uri, "r", ctx=ctx)
    try:
        array_names = storage_formats[storage_version]

        # Resolve every array URI first so that a missing member is reported
        # before any fragment has been deleted.
        array_uris = []
        if array_names["UPDATES_ARRAY_NAME"] in group:
            array_uris.append(group[array_names["UPDATES_ARRAY_NAME"]].uri)

        if index_type == "FLAT":
            array_uris.append(group[array_names["PARTS_ARRAY_NAME"]].uri)
            # The ids array is optional for FLAT indexes.
            if array_names["IDS_ARRAY_NAME"] in group:
                array_uris.append(group[array_names["IDS_ARRAY_NAME"]].uri)
        elif index_type == "IVF_FLAT":
            for key in (
                "PARTS_ARRAY_NAME",
                "CENTROIDS_ARRAY_NAME",
                "INDEX_ARRAY_NAME",
                "IDS_ARRAY_NAME",
            ):
                array_uris.append(group[array_names[key]].uri)

        for array_uri in array_uris:
            # "m" is the mode the arrays are opened in for fragment deletion.
            with tiledb.open(array_uri, "m", ctx=ctx) as A:
                A.delete_fragments(0, timestamp)
    finally:
        group.close()

apis/python/test/test_ingestion.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,15 +403,27 @@ def test_ivf_flat_ingestion_with_updates_and_timetravel(tmp_path):
403403
index.update(vector=data[i].astype(dtype), external_id=i + update_ids_offset, timestamp=i)
404404
updated_ids[i] = i + update_ids_offset
405405

406+
index = IVFFlatIndex(uri=index_uri)
407+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
408+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
409+
index = IVFFlatIndex(uri=index_uri)
410+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
411+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
406412
index = IVFFlatIndex(uri=index_uri, timestamp=101)
407413
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
408414
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
409415
index = IVFFlatIndex(uri=index_uri, timestamp=(0, 101))
410416
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
411417
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
418+
index = IVFFlatIndex(uri=index_uri, timestamp=(0, None))
419+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
420+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
412421
index = IVFFlatIndex(uri=index_uri, timestamp=(2, 101))
413422
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
414423
assert 0.05 <= accuracy(result, gt_i, updated_ids=updated_ids, only_updated_ids=True) <= 0.15
424+
index = IVFFlatIndex(uri=index_uri, timestamp=(2, None))
425+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
426+
assert 0.05 <= accuracy(result, gt_i, updated_ids=updated_ids, only_updated_ids=True) <= 0.15
415427

416428
# Timetravel with partial read from updates table
417429
updated_ids_part = {}
@@ -434,15 +446,24 @@ def test_ivf_flat_ingestion_with_updates_and_timetravel(tmp_path):
434446

435447
# Consolidate updates
436448
index = index.consolidate_updates()
449+
index = IVFFlatIndex(uri=index_uri)
450+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
451+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
437452
index = IVFFlatIndex(uri=index_uri, timestamp=101)
438453
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
439454
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
440455
index = IVFFlatIndex(uri=index_uri, timestamp=(0, 101))
441456
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
442457
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
458+
index = IVFFlatIndex(uri=index_uri, timestamp=(0, None))
459+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
460+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
443461
index = IVFFlatIndex(uri=index_uri, timestamp=(2, 101))
444462
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
445463
assert 0.05 <= accuracy(result, gt_i, updated_ids=updated_ids, only_updated_ids=True) <= 0.15
464+
index = IVFFlatIndex(uri=index_uri, timestamp=(2, None))
465+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
466+
assert 0.05 <= accuracy(result, gt_i, updated_ids=updated_ids, only_updated_ids=True) <= 0.15
446467

447468
# Timetravel with partial read from updates table
448469
updated_ids_part = {}
@@ -466,6 +487,72 @@ def test_ivf_flat_ingestion_with_updates_and_timetravel(tmp_path):
466487
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
467488
assert accuracy(result, gt_i) == 1.0
468489

490+
# Clear history before the latest ingestion
491+
Index.clear_history(uri=index_uri, timestamp=index.latest_ingestion_timestamp-1)
492+
index = IVFFlatIndex(uri=index_uri, timestamp=1)
493+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
494+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
495+
index = IVFFlatIndex(uri=index_uri, timestamp=51)
496+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
497+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
498+
index = IVFFlatIndex(uri=index_uri, timestamp=101)
499+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
500+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
501+
index = IVFFlatIndex(uri=index_uri)
502+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
503+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
504+
index = IVFFlatIndex(uri=index_uri, timestamp=(0, 51))
505+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
506+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
507+
index = IVFFlatIndex(uri=index_uri, timestamp=(0, 101))
508+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
509+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
510+
index = IVFFlatIndex(uri=index_uri, timestamp=(0, None))
511+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
512+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
513+
index = IVFFlatIndex(uri=index_uri, timestamp=(2, 51))
514+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
515+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
516+
index = IVFFlatIndex(uri=index_uri, timestamp=(2, 101))
517+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
518+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
519+
index = IVFFlatIndex(uri=index_uri, timestamp=(2, None))
520+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
521+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
522+
523+
# Clear all history
524+
Index.clear_history(uri=index_uri, timestamp=index.latest_ingestion_timestamp)
525+
index = IVFFlatIndex(uri=index_uri, timestamp=1)
526+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
527+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
528+
index = IVFFlatIndex(uri=index_uri, timestamp=51)
529+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
530+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
531+
index = IVFFlatIndex(uri=index_uri, timestamp=101)
532+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
533+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
534+
index = IVFFlatIndex(uri=index_uri)
535+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
536+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
537+
index = IVFFlatIndex(uri=index_uri, timestamp=(0, 51))
538+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
539+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
540+
index = IVFFlatIndex(uri=index_uri, timestamp=(0, 101))
541+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
542+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
543+
index = IVFFlatIndex(uri=index_uri, timestamp=(0, None))
544+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
545+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
546+
index = IVFFlatIndex(uri=index_uri, timestamp=(2, 51))
547+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
548+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
549+
index = IVFFlatIndex(uri=index_uri, timestamp=(2, 101))
550+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
551+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
552+
index = IVFFlatIndex(uri=index_uri, timestamp=(2, None))
553+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
554+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
555+
469556

470557
def test_ivf_flat_ingestion_with_additions_and_timetravel(tmp_path):
471558
dataset_dir = os.path.join(tmp_path, "dataset")

0 commit comments

Comments
 (0)