Skip to content

Commit e5dbb5e

Browse files
author
Nikos Papailiou
committed
Timetravel implementation
1 parent ba91a70 commit e5dbb5e

File tree

14 files changed

+303
-164
lines changed

14 files changed

+303
-164
lines changed

apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,9 @@ def __init__(
2222
self,
2323
uri: str,
2424
config: Optional[Mapping[str, Any]] = None,
25+
timestamp=None,
2526
):
26-
super().__init__(uri=uri, config=config)
27+
super().__init__(uri=uri, config=config, timestamp=timestamp)
2728
self.index_type = "FLAT"
2829
self._index = None
2930
self.db_uri = self.group[storage_formats[self.storage_version]["PARTS_ARRAY_NAME"] + self.index_version].uri
@@ -35,12 +36,13 @@ def __init__(
3536
self.db_uri,
3637
ctx=self.ctx,
3738
config=config,
39+
timestamp=self.base_array_timestamp,
3840
)
3941
self.ids_uri = self.group[
4042
storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
4143
].uri
4244
if tiledb.array_exists(self.ids_uri, self.ctx):
43-
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, 0)
45+
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, 0, self.base_array_timestamp)
4446
else:
4547
self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))
4648

apis/python/src/tiledb/vector_search/index.py

Lines changed: 51 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
import concurrent.futures as futures
2+
import json
23
import os
34
import numpy as np
45
import sys
6+
import time
57

68
from tiledb.vector_search.module import *
79
from tiledb.vector_search.storage_formats import storage_formats
@@ -27,6 +29,7 @@ def __init__(
2729
self,
2830
uri: str,
2931
config: Optional[Mapping[str, Any]] = None,
32+
timestamp: int = None,
3033
):
3134
# If the user passes a tiledb python Config object convert to a dictionary
3235
if isinstance(config, tiledb.Config):
@@ -39,6 +42,36 @@ def __init__(
3942
self.storage_version = self.group.meta.get("storage_version", "0.1")
4043
self.update_arrays_uri = None
4144
self.index_version = self.group.meta.get("index_version", "")
45+
46+
self.ingestion_timestamps = list(json.loads(self.group.meta.get("ingestion_timestamps", "[]")))
47+
print(f"ingestion_timestamps: {self.ingestion_timestamps}")
48+
self.base_array_timestamp = self.ingestion_timestamps[len(self.ingestion_timestamps)-1]
49+
print(f"base_array_timestamp: {self.base_array_timestamp}")
50+
self.query_base_array = True
51+
self.update_array_timestamp = (self.base_array_timestamp+1, None)
52+
if timestamp is not None:
53+
if isinstance(timestamp, tuple):
54+
if len(timestamp) != 2:
55+
raise ValueError("'timestamp' argument expects either int or tuple(start: int, end: int)")
56+
if timestamp[0] is not None:
57+
if timestamp[0] > self.ingestion_timestamps[0]:
58+
self.query_base_array = False
59+
self.update_array_timestamp = timestamp
60+
else:
61+
self.base_array_timestamp = self.ingestion_timestamps[0]
62+
self.update_array_timestamp = (self.base_array_timestamp+1, timestamp[1])
63+
else:
64+
self.base_array_timestamp = self.ingestion_timestamps[0]
65+
self.update_array_timestamp = (self.base_array_timestamp+1, timestamp[1])
66+
elif isinstance(timestamp, int):
67+
for ingestion_timestamp in self.ingestion_timestamps:
68+
if ingestion_timestamp <= timestamp:
69+
self.base_array_timestamp = ingestion_timestamp
70+
self.update_array_timestamp = (self.base_array_timestamp+1, timestamp)
71+
else:
72+
raise TypeError("Unexpected argument type for 'timestamp' keyword argument")
73+
print(f"base_array_timestamp: {self.base_array_timestamp}")
74+
print(f"update_array_timestamp: {self.update_array_timestamp}")
4275
self.thread_executor = futures.ThreadPoolExecutor()
4376

4477
def query(self, queries: np.ndarray, k, **kwargs):
@@ -56,6 +89,7 @@ def query(self, queries: np.ndarray, k, **kwargs):
5689
self.dtype,
5790
self.update_arrays_uri,
5891
int(os.cpu_count() / 2),
92+
self.update_array_timestamp,
5993
)
6094
internal_results_d, internal_results_i = self.query_internal(
6195
queries, retrieval_k, **kwargs
@@ -108,11 +142,11 @@ def query(self, queries: np.ndarray, k, **kwargs):
108142

109143
@staticmethod
110144
def query_additions(
111-
queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8
145+
queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8, timestamp=None
112146
):
113147
assert queries.dtype == np.float32
114148
additions_vectors, additions_external_ids, updated_ids = Index.read_additions(
115-
update_arrays_uri
149+
update_arrays_uri, timestamp
116150
)
117151
if additions_vectors is None:
118152
return None, None, updated_ids
@@ -128,10 +162,10 @@ def query_additions(
128162
return np.transpose(np.array(d)), np.transpose(np.array(i)), updated_ids
129163

130164
@staticmethod
131-
def read_additions(update_arrays_uri) -> (np.ndarray, np.array):
165+
def read_additions(update_arrays_uri, timestamp=None) -> (np.ndarray, np.array):
132166
if update_arrays_uri is None:
133167
return None, None, np.array([], np.uint64)
134-
updates_array = tiledb.open(update_arrays_uri, mode="r")
168+
updates_array = tiledb.open(update_arrays_uri, mode="r", timestamp=timestamp)
135169
q = updates_array.query(attrs=("vector",), coords=True)
136170
data = q[:]
137171
updates_array.close()
@@ -149,30 +183,30 @@ def read_additions(update_arrays_uri) -> (np.ndarray, np.array):
149183
def query_internal(self, queries: np.ndarray, k, **kwargs):
150184
raise NotImplementedError
151185

152-
def update(self, vector: np.array, external_id: np.uint64):
153-
updates_array = self.open_updates_array()
186+
def update(self, vector: np.array, external_id: np.uint64, timestamp: int = None):
187+
updates_array = self.open_updates_array(timestamp=timestamp)
154188
vectors = np.empty((1), dtype="O")
155189
vectors[0] = vector
156190
updates_array[external_id] = {"vector": vectors}
157191
updates_array.close()
158192
self.consolidate_update_fragments()
159193

160-
def update_batch(self, vectors: np.ndarray, external_ids: np.array):
161-
updates_array = self.open_updates_array()
194+
def update_batch(self, vectors: np.ndarray, external_ids: np.array, timestamp: int = None):
195+
updates_array = self.open_updates_array(timestamp=timestamp)
162196
updates_array[external_ids] = {"vector": vectors}
163197
updates_array.close()
164198
self.consolidate_update_fragments()
165199

166-
def delete(self, external_id: np.uint64):
167-
updates_array = self.open_updates_array()
200+
def delete(self, external_id: np.uint64, timestamp: int = None):
201+
updates_array = self.open_updates_array(timestamp=timestamp)
168202
deletes = np.empty((1), dtype="O")
169203
deletes[0] = np.array([], dtype=self.dtype)
170204
updates_array[external_id] = {"vector": deletes}
171205
updates_array.close()
172206
self.consolidate_update_fragments()
173207

174-
def delete_batch(self, external_ids: np.array):
175-
updates_array = self.open_updates_array()
208+
def delete_batch(self, external_ids: np.array, timestamp: int = None):
209+
updates_array = self.open_updates_array(timestamp=timestamp)
176210
deletes = np.empty((len(external_ids)), dtype="O")
177211
for i in range(len(external_ids)):
178212
deletes[i] = np.array([], dtype=self.dtype)
@@ -189,7 +223,7 @@ def consolidate_update_fragments(self):
189223
def get_updates_uri(self):
190224
return self.update_arrays_uri
191225

192-
def open_updates_array(self):
226+
def open_updates_array(self, timestamp: int = None):
193227
if self.update_arrays_uri is None:
194228
updates_array_name = storage_formats[self.storage_version][
195229
"UPDATES_ARRAY_NAME"
@@ -217,7 +251,9 @@ def open_updates_array(self):
217251
self.group.close()
218252
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(self.config))
219253
self.update_arrays_uri = updates_array_uri
220-
return tiledb.open(self.update_arrays_uri, mode="w")
254+
if timestamp is None:
255+
timestamp = int(time.time() * 1000)
256+
return tiledb.open(self.update_arrays_uri, mode="w", timestamp=timestamp)
221257

222258
def consolidate_updates(self):
223259
from tiledb.vector_search.ingestion import ingest
@@ -230,9 +266,5 @@ def consolidate_updates(self):
230266
external_ids_uri=self.ids_uri,
231267
updates_uri=self.update_arrays_uri,
232268
)
233-
tiledb.Array.delete_array(self.update_arrays_uri)
234-
self.group.close()
235-
self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
236-
self.group.remove(self.update_arrays_uri)
237-
self.group.close()
269+
new_index.update_arrays_uri = self.update_arrays_uri
238270
return new_index

0 commit comments

Comments
 (0)