Skip to content

Commit eb8a28b

Browse files
Merge pull request #148 from TileDB-Inc/npapa/timetravel
Timetravel implementation
2 parents 39f94c3 + 18e4818 commit eb8a28b

File tree

20 files changed

+872
-285
lines changed

20 files changed

+872
-285
lines changed

apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,29 +22,35 @@ def __init__(
2222
self,
2323
uri: str,
2424
config: Optional[Mapping[str, Any]] = None,
25+
timestamp=None,
2526
):
26-
super().__init__(uri=uri, config=config)
27+
super().__init__(uri=uri, config=config, timestamp=timestamp)
2728
self.index_type = "FLAT"
2829
self._index = None
2930
self.db_uri = self.group[storage_formats[self.storage_version]["PARTS_ARRAY_NAME"] + self.index_version].uri
3031
schema = tiledb.ArraySchema.load(
3132
self.db_uri, ctx=tiledb.Ctx(self.config)
3233
)
33-
self.size = schema.domain.dim(1).domain[1]+1
34+
if self.base_size == -1:
35+
self.size = schema.domain.dim(1).domain[1] + 1
36+
else:
37+
self.size = self.base_size
3438
self._db = load_as_matrix(
3539
self.db_uri,
3640
ctx=self.ctx,
3741
config=config,
42+
size=self.size,
43+
timestamp=self.base_array_timestamp,
3844
)
39-
4045
# Check for existence of ids array. Previous versions were not using external_ids in the ingestion assuming
4146
# that the external_ids were the position of the vector in the array.
4247
if storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version in self.group:
4348
self.ids_uri = self.group[
4449
storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
4550
].uri
46-
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, 0)
51+
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, self.size, self.base_array_timestamp)
4752
else:
53+
self.ids_uri = ""
4854
self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))
4955

5056
dtype = self.group.meta.get("dtype", None)

apis/python/src/tiledb/vector_search/index.py

Lines changed: 180 additions & 40 deletions
Large diffs are not rendered by default.

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 152 additions & 75 deletions
Large diffs are not rendered by default.

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ def __init__(
3333
self,
3434
uri: str,
3535
config: Optional[Mapping[str, Any]] = None,
36+
timestamp=None,
3637
memory_budget: int = -1,
3738
):
38-
super().__init__(uri=uri, config=config)
39+
super().__init__(uri=uri, config=config, timestamp=timestamp)
3940
self.index_type = "IVF_FLAT"
4041
self.db_uri = self.group[
4142
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"] + self.index_version
@@ -52,9 +53,9 @@ def __init__(
5253
self.memory_budget = memory_budget
5354

5455
self._centroids = load_as_matrix(
55-
self.centroids_uri, ctx=self.ctx, config=config
56+
self.centroids_uri, ctx=self.ctx, config=config, timestamp=self.base_array_timestamp
5657
)
57-
self._index = read_vector_u64(self.ctx, self.index_array_uri, 0, 0)
58+
self._index = read_vector_u64(self.ctx, self.index_array_uri, 0, 0, self.base_array_timestamp)
5859

5960

6061
dtype = self.group.meta.get("dtype", None)
@@ -73,13 +74,15 @@ def __init__(
7374
)
7475
self.partitions = schema.domain.dim("cols").domain[1] + 1
7576

76-
self.size = self._index[self.partitions]
77-
77+
if self.base_size == -1:
78+
self.size = self._index[self.partitions]
79+
else:
80+
self.size = self.base_size
7881

7982
# TODO pass in a context
8083
if self.memory_budget == -1:
81-
self._db = load_as_matrix(self.db_uri, ctx=self.ctx, config=config, size=self.size)
82-
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, self.size)
84+
self._db = load_as_matrix(self.db_uri, ctx=self.ctx, config=config, size=self.size, timestamp=self.base_array_timestamp)
85+
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, self.size, self.base_array_timestamp)
8386

8487
def query_internal(
8588
self,
@@ -157,6 +160,7 @@ def query_internal(
157160
nthreads=nthreads,
158161
ctx=self.ctx,
159162
use_nuv_implementation=use_nuv_implementation,
163+
timestamp=self.base_array_timestamp,
160164
)
161165

162166
return np.transpose(np.array(d)), np.transpose(np.array(i))
@@ -229,19 +233,34 @@ def dist_qv_udf(
229233
indices: np.array,
230234
k_nn: int,
231235
config: Optional[Mapping[str, Any]] = None,
236+
timestamp: int = 0,
232237
):
233238
queries_m = array_to_matrix(np.transpose(query_vectors))
234-
r = dist_qv(
235-
dtype=dtype,
236-
parts_uri=parts_uri,
237-
ids_uri=ids_uri,
238-
query_vectors=queries_m,
239-
active_partitions=active_partitions,
240-
active_queries=active_queries,
241-
indices=indices,
242-
k_nn=k_nn,
243-
ctx=Ctx(config),
244-
)
239+
if timestamp == 0:
240+
r = dist_qv(
241+
dtype=dtype,
242+
parts_uri=parts_uri,
243+
ids_uri=ids_uri,
244+
query_vectors=queries_m,
245+
active_partitions=active_partitions,
246+
active_queries=active_queries,
247+
indices=indices,
248+
k_nn=k_nn,
249+
ctx=Ctx(config),
250+
)
251+
else:
252+
r = dist_qv(
253+
dtype=dtype,
254+
parts_uri=parts_uri,
255+
ids_uri=ids_uri,
256+
query_vectors=queries_m,
257+
active_partitions=active_partitions,
258+
active_queries=active_queries,
259+
indices=indices,
260+
k_nn=k_nn,
261+
ctx=Ctx(config),
262+
timestamp=timestamp,
263+
)
245264
results = []
246265
for q in range(len(r)):
247266
tmp_results = []
@@ -308,6 +327,7 @@ def dist_qv_udf(
308327
indices=np.array(self._index),
309328
k_nn=k,
310329
config=config,
330+
timestamp=self.base_array_timestamp,
311331
resource_class="large",
312332
image_name="3.9-vectorsearch",
313333
)

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ static void declare_qv_query_heap_finite_ram(py::module& m, const std::string& s
150150
size_t nprobe,
151151
size_t k_nn,
152152
size_t upper_bound,
153-
size_t nthreads) -> py::tuple { //std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> { // TODO change return type
153+
size_t nthreads,
154+
uint64_t timestamp) -> py::tuple { //std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> { // TODO change return type
154155

155156
auto r = detail::ivf::qv_query_heap_finite_ram<T, Id_Type>(
156157
ctx,
@@ -162,7 +163,8 @@ static void declare_qv_query_heap_finite_ram(py::module& m, const std::string& s
162163
nprobe,
163164
k_nn,
164165
upper_bound,
165-
nthreads);
166+
nthreads,
167+
timestamp);
166168
return make_python_pair(std::move(r));
167169
}, py::keep_alive<1,2>());
168170
}
@@ -204,7 +206,8 @@ static void declare_nuv_query_heap_finite_ram(py::module& m, const std::string&
204206
size_t nprobe,
205207
size_t k_nn,
206208
size_t upper_bound,
207-
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<uint64_t>> { // TODO change return type
209+
size_t nthreads,
210+
uint64_t timestamp) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<uint64_t>> { // TODO change return type
208211

209212
auto r = detail::ivf::nuv_query_heap_finite_ram_reg_blocked<T, Id_Type>(
210213
ctx,
@@ -216,7 +219,8 @@ static void declare_nuv_query_heap_finite_ram(py::module& m, const std::string&
216219
nprobe,
217220
k_nn,
218221
upper_bound,
219-
nthreads);
222+
nthreads,
223+
timestamp);
220224
return r;
221225
}, py::keep_alive<1,2>());
222226
}
@@ -234,7 +238,8 @@ static void declare_ivf_index(py::module& m, const std::string& suffix) {
234238
const std::string& id_uri,
235239
size_t start_pos,
236240
size_t end_pos,
237-
size_t nthreads) -> int {
241+
size_t nthreads,
242+
uint64_t timestamp) -> int {
238243
return detail::ivf::ivf_index<T, uint64_t, float>(
239244
ctx,
240245
db,
@@ -246,7 +251,8 @@ static void declare_ivf_index(py::module& m, const std::string& suffix) {
246251
id_uri,
247252
start_pos,
248253
end_pos,
249-
nthreads);
254+
nthreads,
255+
timestamp);
250256
}, py::keep_alive<1,2>());
251257
}
252258

@@ -263,7 +269,8 @@ static void declare_ivf_index_tdb(py::module& m, const std::string& suffix) {
263269
const std::string& id_uri,
264270
size_t start_pos,
265271
size_t end_pos,
266-
size_t nthreads) -> int {
272+
size_t nthreads,
273+
uint64_t timestamp) -> int {
267274
return detail::ivf::ivf_index<T, uint64_t, float>(
268275
ctx,
269276
db_uri,
@@ -275,7 +282,8 @@ static void declare_ivf_index_tdb(py::module& m, const std::string& suffix) {
275282
id_uri,
276283
start_pos,
277284
end_pos,
278-
nthreads);
285+
nthreads,
286+
timestamp);
279287
}, py::keep_alive<1,2>());
280288
}
281289

@@ -302,7 +310,7 @@ static void declareColMajorMatrixSubclass(py::module& mod,
302310
// TODO auto-namify
303311
PyTMatrix cls(mod, (name + suffix).c_str(), py::buffer_protocol());
304312

305-
cls.def(py::init<const Ctx&, std::string, size_t>(), py::keep_alive<1,2>());
313+
cls.def(py::init<const Ctx&, std::string, size_t, size_t, size_t, size_t, uint64_t>(), py::keep_alive<1,2>());
306314

307315
if constexpr (std::is_same<P, tdbColMajorMatrix<T>>::value) {
308316
cls.def("load", &TMatrix::load);
@@ -378,7 +386,8 @@ static void declare_dist_qv(py::module& m, const std::string& suffix) {
378386
std::vector<std::vector<int>>& active_queries,
379387
std::vector<shuffled_ids_type>& indices,
380388
const std::string& id_uri,
381-
size_t k_nn
389+
size_t k_nn,
390+
uint64_t timestamp
382391
/* size_t nthreads TODO: optional arg w/ fallback to C++ default arg */
383392
) { /* TODO return type */
384393
return detail::ivf::dist_qv_finite_ram_part<T, shuffled_ids_type>(
@@ -389,7 +398,8 @@ static void declare_dist_qv(py::module& m, const std::string& suffix) {
389398
active_queries,
390399
indices,
391400
id_uri,
392-
k_nn);
401+
k_nn,
402+
timestamp);
393403
}, py::keep_alive<1,2>());
394404
}
395405

@@ -448,8 +458,26 @@ PYBIND11_MODULE(_tiledbvspy, m) {
448458
declareStdVector<size_t>(m, "szt");
449459
}
450460

451-
m.def("read_vector_u32", &read_vector<uint32_t>, "Read a vector from TileDB");
452-
m.def("read_vector_u64", &read_vector<uint64_t>, "Read a vector from TileDB");
461+
m.def("read_vector_u32",
462+
[](const tiledb::Context& ctx,
463+
const std::string& uri,
464+
size_t start_pos,
465+
size_t end_pos,
466+
uint64_t timestamp) -> std::vector<uint32_t> {
467+
auto r = read_vector<uint32_t>(ctx, uri, start_pos, end_pos, timestamp);
468+
return r;
469+
});
470+
m.def("read_vector_u64",
471+
[](const tiledb::Context& ctx,
472+
const std::string& uri,
473+
size_t start_pos,
474+
size_t end_pos,
475+
uint64_t timestamp) -> std::vector<uint64_t> {
476+
auto r = read_vector<uint64_t>(ctx, uri, start_pos, end_pos, timestamp);
477+
return r;
478+
});
479+
// m.def("read_vector_u32", &read_vector<uint32_t>, "Read a vector from TileDB");
480+
// m.def("read_vector_u64", &read_vector<uint64_t>, "Read a vector from TileDB");
453481

454482
m.def("_create_vector_u64", []() {
455483
auto v = std::vector<uint64_t>(10);

apis/python/src/tiledb/vector_search/module.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ def load_as_matrix(
1212
path: str,
1313
ctx: "Ctx" = None,
1414
config: Optional[Mapping[str, Any]] = None,
15-
size: int = 0
15+
size: int = 0,
16+
timestamp: int = 0,
1617
):
1718
"""
1819
Load array as Matrix class
@@ -36,17 +37,17 @@ def load_as_matrix(
3637
a = tiledb.ArraySchema.load(path, ctx=tiledb.Ctx(config))
3738
dtype = a.attr(0).dtype
3839
if dtype == np.float32:
39-
m = tdbColMajorMatrix_f32(ctx, path, size)
40+
m = tdbColMajorMatrix_f32(ctx, path, 0, 0, 0, size, timestamp)
4041
elif dtype == np.float64:
41-
m = tdbColMajorMatrix_f64(ctx, path, size)
42+
m = tdbColMajorMatrix_f64(ctx, path, 0, 0, 0, size, timestamp)
4243
elif dtype == np.int32:
43-
m = tdbColMajorMatrix_i32(ctx, path, size)
44+
m = tdbColMajorMatrix_i32(ctx, path, 0, 0, 0, size, timestamp)
4445
elif dtype == np.int32:
45-
m = tdbColMajorMatrix_i64(ctx, path, size)
46+
m = tdbColMajorMatrix_i64(ctx, path, 0, 0, 0, size, timestamp)
4647
elif dtype == np.uint8:
47-
m = tdbColMajorMatrix_u8(ctx, path, size)
48+
m = tdbColMajorMatrix_u8(ctx, path, 0, 0, 0, size, timestamp)
4849
# elif dtype == np.uint64:
49-
# return tdbColMajorMatrix_u64(ctx, path, size)
50+
# return tdbColMajorMatrix_u64(ctx, path, size, timestamp)
5051
else:
5152
raise ValueError("Unsupported Matrix dtype: {}".format(a.attr(0).dtype))
5253
m.load()
@@ -148,6 +149,7 @@ def ivf_index_tdb(
148149
start: int = 0,
149150
end: int = 0,
150151
nthreads: int = 0,
152+
timestamp: int = 0,
151153
config: Dict = None,
152154
):
153155
if config is None:
@@ -156,7 +158,7 @@ def ivf_index_tdb(
156158
ctx = Ctx(config)
157159

158160
args = tuple(
159-
[ctx, db_uri, external_ids_uri, deleted_ids, centroids_uri, parts_uri, index_array_uri, id_uri, start, end, nthreads]
161+
[ctx, db_uri, external_ids_uri, deleted_ids, centroids_uri, parts_uri, index_array_uri, id_uri, start, end, nthreads, timestamp]
160162
)
161163

162164
if dtype == np.float32:
@@ -179,6 +181,7 @@ def ivf_index(
179181
start: int = 0,
180182
end: int = 0,
181183
nthreads: int = 0,
184+
timestamp: int = 0,
182185
config: Dict = None,
183186
):
184187
if config is None:
@@ -187,7 +190,7 @@ def ivf_index(
187190
ctx = Ctx(config)
188191

189192
args = tuple(
190-
[ctx, db, external_ids, deleted_ids, centroids_uri, parts_uri, index_array_uri, id_uri, start, end, nthreads]
193+
[ctx, db, external_ids, deleted_ids, centroids_uri, parts_uri, index_array_uri, id_uri, start, end, nthreads, timestamp]
191194
)
192195

193196
if dtype == np.float32:
@@ -280,6 +283,7 @@ def ivf_query(
280283
nthreads: int,
281284
ctx: "Ctx" = None,
282285
use_nuv_implementation: bool = False,
286+
timestamp: int = 0,
283287
):
284288
"""
285289
Run IVF vector query using a memory budget
@@ -308,6 +312,8 @@ def ivf_query(
308312
Number of theads
309313
ctx: Ctx
310314
Tiledb Context
315+
timestamp: int
316+
Read timestamp
311317
"""
312318
if ctx is None:
313319
ctx = Ctx({})
@@ -324,6 +330,7 @@ def ivf_query(
324330
k_nn,
325331
memory_budget,
326332
nthreads,
333+
timestamp,
327334
]
328335
)
329336

@@ -360,6 +367,7 @@ def dist_qv(
360367
indices: np.array,
361368
k_nn: int,
362369
ctx: "Ctx" = None,
370+
timestamp: int = 0,
363371
):
364372
if ctx is None:
365373
ctx = Ctx({})
@@ -373,6 +381,7 @@ def dist_qv(
373381
StdVector_u64(indices),
374382
ids_uri,
375383
k_nn,
384+
timestamp,
376385
]
377386
)
378387
if dtype == np.float32:

0 commit comments

Comments
 (0)