
Commit 53aa88a

Timetravel implementation

Author: Nikos Papailiou (committed)
1 parent e5dbb5e commit 53aa88a

File tree

5 files changed: +172 -54 lines changed

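The diffs below add time travel to vector index ingestion, updates, and queries. As orientation, here is a condensed sketch of the API surface, pieced together from the new test in test_ingestion.py; the IVFFlatIndex import path is assumed, the URIs and the vector value are placeholders, and ingest is imported the same way index.py itself does:

    import numpy as np

    from tiledb.vector_search.index import IVFFlatIndex  # assumed import path
    from tiledb.vector_search.ingestion import ingest

    index_uri = "/tmp/demo_index"  # placeholder

    # Ingest a base array stamped with an explicit timestamp.
    index = ingest(
        index_type="IVF_FLAT",
        index_uri=index_uri,
        source_uri="/tmp/data.u8bin",  # placeholder
        partitions=10,
        index_timestamp=1,
    )

    # Updates and deletes carry their own timestamps.
    index.delete(external_id=2, timestamp=2)
    index.update(vector=np.zeros(128, np.uint8), external_id=1002, timestamp=2)

    # Reopen as of a point in time, or restricted to a timestamp range.
    index_at = IVFFlatIndex(uri=index_uri, timestamp=51)
    index_over = IVFFlatIndex(uri=index_uri, timestamp=(2, 101))

    # Consolidation folds the updates into a new base array stamped at the
    # latest update fragment timestamp.
    index = index_at.consolidate_updates()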

apis/python/src/tiledb/vector_search/index.py

Lines changed: 44 additions & 33 deletions
@@ -29,7 +29,7 @@ def __init__(
         self,
         uri: str,
         config: Optional[Mapping[str, Any]] = None,
-        timestamp: int = None,
+        timestamp=None,
     ):
         # If the user passes a tiledb python Config object convert to a dictionary
         if isinstance(config, tiledb.Config):
@@ -40,13 +40,14 @@ def __init__(
         self.ctx = Ctx(config)
         self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(config))
         self.storage_version = self.group.meta.get("storage_version", "0.1")
-        self.update_arrays_uri = None
+        updates_array_name = storage_formats[self.storage_version][
+            "UPDATES_ARRAY_NAME"
+        ]
+        self.updates_array_uri = f"{self.group.uri}/{updates_array_name}"
         self.index_version = self.group.meta.get("index_version", "")
-
         self.ingestion_timestamps = list(json.loads(self.group.meta.get("ingestion_timestamps", "[]")))
-        print(f"ingestion_timestamps: {self.ingestion_timestamps}")
-        self.base_array_timestamp = self.ingestion_timestamps[len(self.ingestion_timestamps)-1]
-        print(f"base_array_timestamp: {self.base_array_timestamp}")
+        self.latest_ingestion_timestamp = self.ingestion_timestamps[len(self.ingestion_timestamps)-1]
+        self.base_array_timestamp = self.latest_ingestion_timestamp
         self.query_base_array = True
         self.update_array_timestamp = (self.base_array_timestamp+1, None)
         if timestamp is not None:
@@ -70,13 +71,14 @@ def __init__(
             self.update_array_timestamp = (self.base_array_timestamp+1, timestamp)
         else:
             raise TypeError("Unexpected argument type for 'timestamp' keyword argument")
-        print(f"base_array_timestamp: {self.base_array_timestamp}")
-        print(f"update_array_timestamp: {self.update_array_timestamp}")
         self.thread_executor = futures.ThreadPoolExecutor()

     def query(self, queries: np.ndarray, k, **kwargs):
-        if self.update_arrays_uri is None:
-            return self.query_internal(queries, k, **kwargs)
+        if not tiledb.array_exists(self.updates_array_uri):
+            if self.query_base_array:
+                return self.query_internal(queries, k, **kwargs)
+            else:
+                return np.full((queries.shape[0], k), MAX_FLOAT_32), np.full((queries.shape[0], k), MAX_UINT64)

         # Query with updates
         # Perform the queries in parallel
@@ -87,13 +89,17 @@ def query(self, queries: np.ndarray, k, **kwargs):
             queries,
             k,
             self.dtype,
-            self.update_arrays_uri,
+            self.updates_array_uri,
             int(os.cpu_count() / 2),
             self.update_array_timestamp,
         )
-        internal_results_d, internal_results_i = self.query_internal(
-            queries, retrieval_k, **kwargs
-        )
+        if self.query_base_array:
+            internal_results_d, internal_results_i = self.query_internal(
+                queries, retrieval_k, **kwargs
+            )
+        else:
+            internal_results_d = np.full((queries.shape[0], k), MAX_FLOAT_32)
+            internal_results_i = np.full((queries.shape[0], k), MAX_UINT64)
         addition_results_d, addition_results_i, updated_ids = future.result()

         # Filter updated vectors
@@ -142,11 +148,11 @@ def query(self, queries: np.ndarray, k, **kwargs):

     @staticmethod
     def query_additions(
-        queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8, timestamp=None
+        queries: np.ndarray, k, dtype, updates_array_uri, nthreads=8, timestamp=None
     ):
         assert queries.dtype == np.float32
         additions_vectors, additions_external_ids, updated_ids = Index.read_additions(
-            update_arrays_uri, timestamp
+            updates_array_uri, timestamp
         )
         if additions_vectors is None:
             return None, None, updated_ids
@@ -162,10 +168,10 @@ def query_additions(
         return np.transpose(np.array(d)), np.transpose(np.array(i)), updated_ids

     @staticmethod
-    def read_additions(update_arrays_uri, timestamp=None) -> (np.ndarray, np.array):
-        if update_arrays_uri is None:
+    def read_additions(updates_array_uri, timestamp=None) -> (np.ndarray, np.array):
+        if updates_array_uri is None:
             return None, None, np.array([], np.uint64)
-        updates_array = tiledb.open(update_arrays_uri, mode="r", timestamp=timestamp)
+        updates_array = tiledb.open(updates_array_uri, mode="r", timestamp=timestamp)
         q = updates_array.query(attrs=("vector",), coords=True)
         data = q[:]
         updates_array.close()
@@ -215,22 +221,22 @@ def delete_batch(self, external_ids: np.array, timestamp: int = None):
         self.consolidate_update_fragments()

     def consolidate_update_fragments(self):
-        fragments_info = tiledb.array_fragments(self.update_arrays_uri)
+        fragments_info = tiledb.array_fragments(self.updates_array_uri)
         if len(fragments_info) > 10:
-            tiledb.consolidate(self.update_arrays_uri)
-            tiledb.vacuum(self.update_arrays_uri)
+            tiledb.consolidate(self.updates_array_uri)
+            tiledb.vacuum(self.updates_array_uri)

     def get_updates_uri(self):
-        return self.update_arrays_uri
+        return self.updates_array_uri

     def open_updates_array(self, timestamp: int = None):
-        if self.update_arrays_uri is None:
+        if timestamp is not None and timestamp <= self.latest_ingestion_timestamp:
+            raise ValueError(f"Updates at a timestamp before the latest_ingestion_timestamp are not supported. "
+                             f"timestamp: {timestamp}, latest_ingestion_timestamp: {self.latest_ingestion_timestamp}")
+        if not tiledb.array_exists(self.updates_array_uri):
             updates_array_name = storage_formats[self.storage_version][
                 "UPDATES_ARRAY_NAME"
             ]
-            updates_array_uri = f"{self.group.uri}/{updates_array_name}"
-            if tiledb.array_exists(updates_array_uri):
-                raise RuntimeError(f"Array {updates_array_uri} already exists.")
             external_id_dim = tiledb.Dim(
                 name="external_id",
                 domain=(0, MAX_UINT64 - 1),
@@ -244,27 +250,32 @@ def open_updates_array(self, timestamp: int = None):
                 attrs=[vector_attr],
                 allows_duplicates=False,
             )
-            tiledb.Array.create(updates_array_uri, updates_schema)
+            tiledb.Array.create(self.updates_array_uri, updates_schema)
             self.group.close()
             self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
-            self.group.add(updates_array_uri, name=updates_array_name)
+            self.group.add(self.updates_array_uri, name=updates_array_name)
             self.group.close()
             self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(self.config))
-            self.update_arrays_uri = updates_array_uri
         if timestamp is None:
             timestamp = int(time.time() * 1000)
-        return tiledb.open(self.update_arrays_uri, mode="w", timestamp=timestamp)
+        return tiledb.open(self.updates_array_uri, mode="w", timestamp=timestamp)

     def consolidate_updates(self):
         from tiledb.vector_search.ingestion import ingest

+        fragments_info = tiledb.array_fragments(self.updates_array_uri, ctx=tiledb.Ctx(self.config))
+        max_timestamp = self.base_array_timestamp
+        for fragment_info in fragments_info:
+            if fragment_info.timestamp_range[1] > max_timestamp:
+                max_timestamp = fragment_info.timestamp_range[1]
         new_index = ingest(
             index_type=self.index_type,
             index_uri=self.uri,
             size=self.size,
             source_uri=self.db_uri,
             external_ids_uri=self.ids_uri,
-            updates_uri=self.update_arrays_uri,
+            updates_uri=self.updates_array_uri,
+            index_timestamp=max_timestamp,
+            config=self.config,
         )
-        new_index.update_arrays_uri = self.update_arrays_uri
         return new_index
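
The updates array above leans on TileDB's native fragment time travel: each update or delete lands in a fragment stamped with an explicit timestamp, and read_additions opens the array with a (start, end) timestamp so only fragments inside that window are visible. A minimal, self-contained illustration of the mechanism, using a hypothetical URI and a toy one-attribute schema rather than the real updates layout:

    import numpy as np
    import tiledb

    uri = "/tmp/timetravel_demo"  # hypothetical URI
    schema = tiledb.ArraySchema(
        domain=tiledb.Domain(
            tiledb.Dim(name="external_id", domain=(0, 999), dtype=np.uint64)
        ),
        sparse=True,
        attrs=[tiledb.Attr(name="val", dtype=np.int32)],
    )
    tiledb.Array.create(uri, schema)

    # Two writes stamped at explicit timestamps, as index.update(..., timestamp=t) does.
    with tiledb.open(uri, mode="w", timestamp=10) as a:
        a[np.array([1], np.uint64)] = {"val": np.array([100], np.int32)}
    with tiledb.open(uri, mode="w", timestamp=20) as a:
        a[np.array([2], np.uint64)] = {"val": np.array([200], np.int32)}

    # A ranged read sees only fragments written inside the window.
    with tiledb.open(uri, mode="r", timestamp=(11, 20)) as a:
        print(a[:]["external_id"])  # [2] -- the write at t=10 is filtered out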

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 7 additions & 1 deletion
@@ -1785,8 +1785,14 @@ def consolidate_and_vacuum(
     ingestion_timestamps = list(json.loads(group.meta.get("ingestion_timestamps", "[]")))
     if partitions == -1:
         partitions = int(group.meta.get("partitions", "-1"))
+
+    if len(ingestion_timestamps) > 0:
+        previous_ingestion_timestamp = ingestion_timestamps[len(ingestion_timestamps)-1]
+        if index_timestamp <= previous_ingestion_timestamp:
+            raise ValueError(f"New ingestion timestamp: {index_timestamp} can't be smaller than the latest ingestion "
+                             f"timestamp: {previous_ingestion_timestamp}")
+
     ingestion_timestamps.append(index_timestamp)
-    print(f"ingestion_timestamps: {ingestion_timestamps}")
     group.close()
     group = tiledb.Group(index_group_uri, "w")
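The guard above keeps the group's ingestion_timestamps metadata strictly increasing, which the Index constructor depends on when it resolves a requested timestamp against past ingestions. The same check in isolation, as a sketch; the function name and plain-dict metadata are illustrative stand-ins for the TileDB group metadata handling:

    import json

    def append_ingestion_timestamp(meta: dict, index_timestamp: int) -> None:
        # Illustrative stand-in for the group.meta logic in ingestion.py above.
        ingestion_timestamps = list(json.loads(meta.get("ingestion_timestamps", "[]")))
        if ingestion_timestamps and index_timestamp <= ingestion_timestamps[-1]:
            raise ValueError(
                f"New ingestion timestamp: {index_timestamp} can't be smaller than "
                f"the latest ingestion timestamp: {ingestion_timestamps[-1]}"
            )
        ingestion_timestamps.append(index_timestamp)
        meta["ingestion_timestamps"] = json.dumps(ingestion_timestamps)

    meta = {}
    append_ingestion_timestamp(meta, 1)
    append_ingestion_timestamp(meta, 101)  # fine: strictly increasing
    # append_ingestion_timestamp(meta, 50)  # would raise ValueError
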
apis/python/test/common.py

Lines changed: 11 additions & 2 deletions
@@ -162,9 +162,13 @@ def create_array(path: str, data):
         A[:] = data


-def accuracy(result, gt, external_ids_offset=0, updated_ids=None):
+def accuracy(result, gt, external_ids_offset=0, updated_ids=None, only_updated_ids=False):
     found = 0
     total = 0
+    if updated_ids is not None:
+        updated_ids_rev = {}
+        for updated_id in updated_ids:
+            updated_ids_rev[updated_ids[updated_id]] = updated_id
     for i in range(len(result)):
         if external_ids_offset != 0:
             temp_result = []
@@ -173,7 +177,12 @@ def accuracy(result, gt, external_ids_offset=0, updated_ids=None):
         elif updated_ids is not None:
             temp_result = []
             for j in range(len(result[i])):
-                uid = updated_ids.get(result[i][j])
+                if result[i][j] in updated_ids:
+                    raise ValueError(f"Found updated id {result[i][j]} in query results.")
+                if only_updated_ids:
+                    if result[i][j] not in updated_ids_rev:
+                        raise ValueError(f"Found not_updated_id {result[i][j]} in query results while expecting only_updated_ids.")
+                uid = updated_ids_rev.get(result[i][j])
                 if uid is not None:
                     temp_result.append(int(uid))
                 else:
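
For orientation, here is how the extended helper behaves with the old-id-to-new-id mapping the tests build. Toy data; accuracy is imported from this common.py, assuming the test directory is on the import path, and the expected value follows from the test assertions below:

    import numpy as np

    from common import accuracy  # apis/python/test/common.py

    gt = np.array([[1, 2, 3]])    # ground-truth neighbor ids for one query
    updated_ids = {2: 1002}       # external id 2 was deleted and re-ingested as 1002

    # Replacement ids are mapped back to their originals before comparison, so
    # returning 1002 in place of 2 should still count as a hit (expected: 1.0).
    print(accuracy(np.array([[1, 1002, 3]]), gt, updated_ids=updated_ids))

    # Returning the deleted original id 2 raises ValueError, as does returning
    # any non-replacement id when only_updated_ids=True.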

apis/python/test/test_ingestion.py

Lines changed: 109 additions & 17 deletions
@@ -288,11 +288,11 @@ def test_ivf_flat_ingestion_with_updates(tmp_path):
     dataset_dir = os.path.join(tmp_path, "dataset")
     index_uri = os.path.join(tmp_path, "array")
     k = 10
-    size = 100000
-    partitions = 100
+    size = 1000
+    partitions = 10
     dimensions = 128
     nqueries = 100
-    nprobe = 20
+    nprobe = 10
     data = create_random_dataset_u8(nb=size, d=dimensions, nq=nqueries, k=k, path=dataset_dir)
     dtype = np.uint8

@@ -303,24 +303,23 @@ def test_ivf_flat_ingestion_with_updates(tmp_path):
         index_uri=index_uri,
         source_uri=os.path.join(dataset_dir, "data.u8bin"),
         partitions=partitions,
-        input_vectors_per_work_item=int(size / 10),
     )
     _, result = index.query(query_vectors, k=k, nprobe=nprobe)
-    assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+    assert accuracy(result, gt_i) == 1.0

     update_ids_offset = MAX_UINT64-size
     updated_ids = {}
     for i in range(100):
         index.delete(external_id=i)
         index.update(vector=data[i].astype(dtype), external_id=i + update_ids_offset)
-        updated_ids[i + update_ids_offset] = i
+        updated_ids[i] = i + update_ids_offset

     _, result = index.query(query_vectors, k=k, nprobe=nprobe)
-    assert accuracy(result, gt_i, updated_ids=updated_ids) > MINIMUM_ACCURACY
+    assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0

     index = index.consolidate_updates()
     _, result = index.query(query_vectors, k=k, nprobe=nprobe)
-    assert accuracy(result, gt_i, updated_ids=updated_ids) > MINIMUM_ACCURACY
+    assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0

 def test_ivf_flat_ingestion_with_batch_updates(tmp_path):
     dataset_dir = os.path.join(tmp_path, "dataset")
@@ -330,7 +329,7 @@ def test_ivf_flat_ingestion_with_batch_updates(tmp_path):
     partitions = 100
     dimensions = 128
     nqueries = 100
-    nprobe = 20
+    nprobe = 100
     data = create_random_dataset_u8(nb=size, d=dimensions, nq=nqueries, k=k, path=dataset_dir)
     dtype = np.uint8

@@ -344,18 +343,18 @@ def test_ivf_flat_ingestion_with_batch_updates(tmp_path):
         input_vectors_per_work_item=int(size / 10),
     )
     _, result = index.query(query_vectors, k=k, nprobe=nprobe)
-    assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+    assert accuracy(result, gt_i) > 0.99

     update_ids = {}
     updated_ids = {}
     update_ids_offset = MAX_UINT64 - size
     for i in range(0, 100000, 2):
-        update_ids[i] = i + update_ids_offset
-        updated_ids[i + update_ids_offset] = i
-    external_ids = np.zeros((len(update_ids) * 2), dtype=np.uint64)
-    updates = np.empty((len(update_ids) * 2), dtype='O')
+        updated_ids[i] = i + update_ids_offset
+        update_ids[i + update_ids_offset] = i
+    external_ids = np.zeros((len(updated_ids) * 2), dtype=np.uint64)
+    updates = np.empty((len(updated_ids) * 2), dtype='O')
     id = 0
-    for prev_id, new_id in update_ids.items():
+    for prev_id, new_id in updated_ids.items():
         external_ids[id] = prev_id
         updates[id] = np.array([], dtype=dtype)
         id += 1
@@ -365,9 +364,102 @@ def test_ivf_flat_ingestion_with_batch_updates(tmp_path):

     index.update_batch(vectors=updates, external_ids=external_ids)
     _, result = index.query(query_vectors, k=k, nprobe=nprobe)
-    assert accuracy(result, gt_i, updated_ids=updated_ids) > MINIMUM_ACCURACY
+    assert accuracy(result, gt_i, updated_ids=updated_ids) > 0.99

     index = index.consolidate_updates()
     _, result = index.query(query_vectors, k=k, nprobe=nprobe)
-    assert accuracy(result, gt_i, updated_ids=updated_ids) > MINIMUM_ACCURACY
+    assert accuracy(result, gt_i, updated_ids=updated_ids) > 0.99
+
+def test_ivf_flat_ingestion_with_updates_and_timetravel(tmp_path):
+    dataset_dir = os.path.join(tmp_path, "dataset")
+    index_uri = os.path.join(tmp_path, "array")
+    k = 10
+    size = 1000
+    partitions = 10
+    dimensions = 128
+    nqueries = 100
+    nprobe = 10
+    data = create_random_dataset_u8(nb=size, d=dimensions, nq=nqueries, k=k, path=dataset_dir)
+    dtype = np.uint8
+
+    query_vectors = get_queries(dataset_dir, dtype=dtype)
+    gt_i, gt_d = get_groundtruth(dataset_dir, k)
+    index = ingest(
+        index_type="IVF_FLAT",
+        index_uri=index_uri,
+        source_uri=os.path.join(dataset_dir, "data.u8bin"),
+        partitions=partitions,
+        index_timestamp=1,
+    )
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i) == 1.0

+    update_ids_offset = MAX_UINT64-size
+    updated_ids = {}
+    for i in range(2, 102):
+        index.delete(external_id=i, timestamp=i)
+        index.update(vector=data[i].astype(dtype), external_id=i + update_ids_offset, timestamp=i)
+        updated_ids[i] = i + update_ids_offset
+
+    index = IVFFlatIndex(uri=index_uri, timestamp=101)
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
+    index = IVFFlatIndex(uri=index_uri, timestamp=(0, 101))
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
+    index = IVFFlatIndex(uri=index_uri, timestamp=(2, 101))
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert 0.05 <= accuracy(result, gt_i, updated_ids=updated_ids, only_updated_ids=True) <= 0.15
+
+    # Timetravel with partial read from updates table
+    updated_ids_part = {}
+    for i in range(2, 52):
+        updated_ids_part[i] = i + update_ids_offset
+    index = IVFFlatIndex(uri=index_uri, timestamp=51)
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i, updated_ids=updated_ids_part) == 1.0
+    index = IVFFlatIndex(uri=index_uri, timestamp=(0, 51))
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i, updated_ids=updated_ids_part) == 1.0
+    index = IVFFlatIndex(uri=index_uri, timestamp=(2, 51))
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert 0.02 <= accuracy(result, gt_i, updated_ids=updated_ids, only_updated_ids=True) <= 0.07
+
+    # Timetravel at previous ingestion timestamp
+    index = IVFFlatIndex(uri=index_uri, timestamp=1)
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i) == 1.0
+
+    # Consolidate updates
+    index = index.consolidate_updates()
+    index = IVFFlatIndex(uri=index_uri, timestamp=101)
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
+    index = IVFFlatIndex(uri=index_uri, timestamp=(0, 101))
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
+    index = IVFFlatIndex(uri=index_uri, timestamp=(2, 101))
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert 0.05 <= accuracy(result, gt_i, updated_ids=updated_ids, only_updated_ids=True) <= 0.15
+
+    # Timetravel with partial read from updates table
+    updated_ids_part = {}
+    for i in range(2, 52):
+        updated_ids_part[i] = i + update_ids_offset
+    index = IVFFlatIndex(uri=index_uri, timestamp=51)
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i, updated_ids=updated_ids_part) == 1.0
+    index = IVFFlatIndex(uri=index_uri, timestamp=(0, 51))
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i, updated_ids=updated_ids_part) == 1.0
+    index = IVFFlatIndex(uri=index_uri, timestamp=(2, 51))
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert 0.02 <= accuracy(result, gt_i, updated_ids=updated_ids, only_updated_ids=True) <= 0.07
+
+    # Timetravel at previous ingestion timestamp
+    index = IVFFlatIndex(uri=index_uri, timestamp=1)
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i) == 1.0
+    index = IVFFlatIndex(uri=index_uri, timestamp=(0, 1))
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i) == 1.0
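
A note on the 0.05..0.15 and 0.02..0.07 bands asserted above: opening the index with timestamp=(2, ...) excludes the base ingestion at timestamp 1, so only the re-ingested vectors are searchable and expected recall is roughly the updated fraction of the dataset:

    size = 1000
    print(100 / size)  # ~0.10 for timestamp=(2, 101), hence the 0.05..0.15 band
    print(50 / size)   # ~0.05 for timestamp=(2, 51), hence the 0.02..0.07 band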
