Skip to content

Commit d9cdbb7

Browse files
author
Nikos Papailiou
committed
Keep track of array sizes
1 parent a3e12aa commit d9cdbb7

File tree

12 files changed

+105
-27
lines changed

12 files changed

+105
-27
lines changed

apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,15 @@ def __init__(
3131
schema = tiledb.ArraySchema.load(
3232
self.db_uri, ctx=tiledb.Ctx(self.config)
3333
)
34-
self.size = schema.domain.dim(1).domain[1]+1
34+
if self.base_size == -1:
35+
self.size = schema.domain.dim(1).domain[1] + 1
36+
else:
37+
self.size = self.base_size
3538
self._db = load_as_matrix(
3639
self.db_uri,
3740
ctx=self.ctx,
3841
config=config,
42+
size=self.size,
3943
timestamp=self.base_array_timestamp,
4044
)
4145
# Check for existence of ids array. Previous versions were not using external_ids in the ingestion assuming
@@ -44,8 +48,9 @@ def __init__(
4448
self.ids_uri = self.group[
4549
storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
4650
].uri
47-
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, 0, self.base_array_timestamp)
51+
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, self.size, self.base_array_timestamp)
4852
else:
53+
self.ids_uri = ""
4954
self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))
5055

5156
dtype = self.group.meta.get("dtype", None)

apis/python/src/tiledb/vector_search/index.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,15 @@ def __init__(
4646
self.updates_array_uri = f"{self.group.uri}/{updates_array_name}"
4747
self.index_version = self.group.meta.get("index_version", "")
4848
self.ingestion_timestamps = list(json.loads(self.group.meta.get("ingestion_timestamps", "[]")))
49-
self.latest_ingestion_timestamp = self.ingestion_timestamps[len(self.ingestion_timestamps)-1]
49+
if len(self.ingestion_timestamps) > 0:
50+
self.latest_ingestion_timestamp = self.ingestion_timestamps[len(self.ingestion_timestamps)-1]
51+
else:
52+
self.latest_ingestion_timestamp = MAX_UINT64
53+
self.base_sizes = list(json.loads(self.group.meta.get("base_sizes", "[]")))
54+
if len(self.base_sizes) > 0:
55+
self.base_size = self.base_sizes[len(self.ingestion_timestamps)-1]
56+
else:
57+
self.base_size = -1
5058
self.base_array_timestamp = self.latest_ingestion_timestamp
5159
self.query_base_array = True
5260
self.update_array_timestamp = (self.base_array_timestamp+1, None)
@@ -59,15 +67,20 @@ def __init__(
5967
self.query_base_array = False
6068
self.update_array_timestamp = timestamp
6169
else:
70+
self.base_size = self.base_sizes[0]
6271
self.base_array_timestamp = self.ingestion_timestamps[0]
6372
self.update_array_timestamp = (self.base_array_timestamp+1, timestamp[1])
6473
else:
74+
self.base_size = self.base_sizes[0]
6575
self.base_array_timestamp = self.ingestion_timestamps[0]
6676
self.update_array_timestamp = (self.base_array_timestamp+1, timestamp[1])
6777
elif isinstance(timestamp, int):
78+
i = 0
6879
for ingestion_timestamp in self.ingestion_timestamps:
6980
if ingestion_timestamp <= timestamp:
7081
self.base_array_timestamp = ingestion_timestamp
82+
self.base_size = self.base_sizes[i]
83+
i += 1
7184
self.update_array_timestamp = (self.base_array_timestamp+1, timestamp)
7285
else:
7386
raise TypeError("Unexpected argument type for 'timestamp' keyword argument")
@@ -274,6 +287,7 @@ def consolidate_updates(self):
274287
size=self.size,
275288
source_uri=self.db_uri,
276289
external_ids_uri=self.ids_uri,
290+
external_ids_type="TILEDB_ARRAY",
277291
updates_uri=self.updates_array_uri,
278292
index_timestamp=max_timestamp,
279293
config=self.config,

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def create_arrays(
341341
logger.debug("Creating ids array")
342342
ids_array_rows_dim = tiledb.Dim(
343343
name="rows",
344-
domain=(0, size - 1),
344+
domain=(0, MAX_INT32),
345345
tile=int(size / partitions),
346346
dtype=np.dtype(np.int32),
347347
)
@@ -373,7 +373,7 @@ def create_arrays(
373373
)
374374
parts_array_cols_dim = tiledb.Dim(
375375
name="cols",
376-
domain=(0, size - 1),
376+
domain=(0, MAX_INT32),
377377
tile=int(size / partitions),
378378
dtype=np.dtype(np.int32),
379379
)
@@ -1784,6 +1784,7 @@ def consolidate_and_vacuum(
17841784
raise err
17851785
group = tiledb.Group(index_group_uri, "r")
17861786
ingestion_timestamps = list(json.loads(group.meta.get("ingestion_timestamps", "[]")))
1787+
base_sizes = list(json.loads(group.meta.get("base_sizes", "[]")))
17871788
if partitions == -1:
17881789
partitions = int(group.meta.get("partitions", "-1"))
17891790

@@ -1820,6 +1821,7 @@ def consolidate_and_vacuum(
18201821
size = in_size
18211822
if size > in_size:
18221823
size = in_size
1824+
base_sizes.append(size)
18231825
logger.debug("Input dataset size %d", size)
18241826
logger.debug("Input dataset dimensions %d", dimensions)
18251827
logger.debug("Vector dimension type %s", vector_type)
@@ -1840,6 +1842,7 @@ def consolidate_and_vacuum(
18401842
group.meta["partitions"] = partitions
18411843
group.meta["storage_version"] = STORAGE_VERSION
18421844
group.meta["ingestion_timestamps"] = json.dumps(ingestion_timestamps)
1845+
group.meta["base_sizes"] = json.dumps(base_sizes)
18431846

18441847
if external_ids is not None:
18451848
external_ids_uri = write_external_ids(

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,10 @@ def __init__(
7474
)
7575
self.partitions = schema.domain.dim("cols").domain[1] + 1
7676

77-
self.size = self._index[self.partitions]
78-
77+
if self.base_size == -1:
78+
self.size = self._index[self.partitions]
79+
else:
80+
self.size = self.base_size
7981

8082
# TODO pass in a context
8183
if self.memory_budget == -1:

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ static void declareColMajorMatrixSubclass(py::module& mod,
310310
// TODO auto-namify
311311
PyTMatrix cls(mod, (name + suffix).c_str(), py::buffer_protocol());
312312

313-
cls.def(py::init<const Ctx&, std::string, size_t, uint64_t>(), py::keep_alive<1,2>());
313+
cls.def(py::init<const Ctx&, std::string, size_t, size_t, size_t, size_t, uint64_t>(), py::keep_alive<1,2>());
314314

315315
if constexpr (std::is_same<P, tdbColMajorMatrix<T>>::value) {
316316
cls.def("load", &TMatrix::load);

apis/python/src/tiledb/vector_search/module.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ def load_as_matrix(
3737
a = tiledb.ArraySchema.load(path, ctx=tiledb.Ctx(config))
3838
dtype = a.attr(0).dtype
3939
if dtype == np.float32:
40-
m = tdbColMajorMatrix_f32(ctx, path, size, timestamp)
40+
m = tdbColMajorMatrix_f32(ctx, path, 0, 0, 0, size, timestamp)
4141
elif dtype == np.float64:
42-
m = tdbColMajorMatrix_f64(ctx, path, size, timestamp)
42+
m = tdbColMajorMatrix_f64(ctx, path, 0, 0, 0, size, timestamp)
4343
elif dtype == np.int32:
44-
m = tdbColMajorMatrix_i32(ctx, path, size, timestamp)
44+
m = tdbColMajorMatrix_i32(ctx, path, 0, 0, 0, size, timestamp)
4545
elif dtype == np.int32:
46-
m = tdbColMajorMatrix_i64(ctx, path, size, timestamp)
46+
m = tdbColMajorMatrix_i64(ctx, path, 0, 0, 0, size, timestamp)
4747
elif dtype == np.uint8:
48-
m = tdbColMajorMatrix_u8(ctx, path, size, timestamp)
48+
m = tdbColMajorMatrix_u8(ctx, path, 0, 0, 0, size, timestamp)
4949
# elif dtype == np.uint64:
5050
# return tdbColMajorMatrix_u64(ctx, path, size, timestamp)
5151
else:

apis/python/test/test_ingestion.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ def test_ivf_flat_ingestion_with_batch_updates(tmp_path):
370370
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
371371
assert accuracy(result, gt_i, updated_ids=updated_ids) > 0.99
372372

373+
373374
def test_ivf_flat_ingestion_with_updates_and_timetravel(tmp_path):
374375
dataset_dir = os.path.join(tmp_path, "dataset")
375376
index_uri = os.path.join(tmp_path, "array")
@@ -463,3 +464,42 @@ def test_ivf_flat_ingestion_with_updates_and_timetravel(tmp_path):
463464
index = IVFFlatIndex(uri=index_uri, timestamp=(0, 1))
464465
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
465466
assert accuracy(result, gt_i) == 1.0
467+
468+
469+
def test_ivf_flat_ingestion_with_additions_and_timetravel(tmp_path):
470+
dataset_dir = os.path.join(tmp_path, "dataset")
471+
index_uri = os.path.join(tmp_path, "array")
472+
k = 100
473+
size = 100
474+
partitions = 1
475+
dimensions = 128
476+
nqueries = 1
477+
nprobe = 1
478+
data = create_random_dataset_u8(nb=size, d=dimensions, nq=nqueries, k=k, path=dataset_dir)
479+
dtype = np.uint8
480+
481+
query_vectors = get_queries(dataset_dir, dtype=dtype)
482+
gt_i, gt_d = get_groundtruth(dataset_dir, k)
483+
index = ingest(
484+
index_type="IVF_FLAT",
485+
index_uri=index_uri,
486+
source_uri=os.path.join(dataset_dir, "data.u8bin"),
487+
partitions=partitions,
488+
index_timestamp=1,
489+
)
490+
_, result = index.query(query_vectors, k=k)
491+
assert accuracy(result, gt_i) == 1.0
492+
493+
update_ids_offset = MAX_UINT64-size
494+
updated_ids = {}
495+
for i in range(100):
496+
index.update(vector=data[i].astype(dtype), external_id=i + update_ids_offset, timestamp=i+2)
497+
updated_ids[i] = i + update_ids_offset
498+
499+
index = IVFFlatIndex(uri=index_uri)
500+
_, result = index.query(query_vectors, k=k)
501+
assert 0.45 < accuracy(result, gt_i) < 0.55
502+
503+
index = index.consolidate_updates()
504+
_, result = index.query(query_vectors, k=k)
505+
assert 0.45 < accuracy(result, gt_i) < 0.55

src/include/detail/ivf/dist_qv.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ auto dist_qv_finite_ram_part(
8080
if (nthreads == 0) {
8181
nthreads = std::thread::hardware_concurrency();
8282
}
83-
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp+1);
83+
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
8484

8585
using score_type = float;
8686
using parts_type =

src/include/detail/ivf/index.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ int ivf_index(
6969
if (nthreads == 0) {
7070
nthreads = std::thread::hardware_concurrency();
7171
}
72-
auto read_temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp+1);
72+
auto read_temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
7373
auto write_temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
7474
auto centroids = tdbColMajorMatrix<centroids_type>(ctx, centroids_uri, 0, read_temporal_policy);
7575
centroids.load();
@@ -191,14 +191,14 @@ int ivf_index(
191191
size_t end_pos = 0,
192192
size_t nthreads = 0,
193193
uint64_t timestamp = 0) {
194-
auto db = tdbColMajorMatrix<T>(ctx, db_uri, 0, 0, start_pos, end_pos);
194+
auto db = tdbColMajorMatrix<T>(ctx, db_uri, 0, 0, start_pos, end_pos, timestamp);
195195
db.load();
196196
std::vector<ids_type> external_ids;
197197
if (external_ids_uri.empty()) {
198198
external_ids = std::vector<ids_type>(db.num_cols());
199199
std::iota(begin(external_ids), end(external_ids), start_pos);
200200
} else {
201-
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp+1);
201+
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
202202
external_ids =
203203
read_vector<ids_type>(ctx, external_ids_uri, start_pos, end_pos, temporal_policy);
204204
}
@@ -231,7 +231,7 @@ int ivf_index(
231231
size_t end_pos = 0,
232232
size_t nthreads = 0,
233233
uint64_t timestamp = 0) {
234-
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp+1);
234+
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
235235
auto db = tdbColMajorMatrix<T>(ctx, db_uri, 0, 0, start_pos, end_pos, temporal_policy);
236236
db.load();
237237
return ivf_index<T, ids_type, centroids_type>(
@@ -268,7 +268,7 @@ int ivf_index(
268268
external_ids = std::vector<ids_type>(db.num_cols());
269269
std::iota(begin(external_ids), end(external_ids), start_pos);
270270
} else {
271-
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp+1);
271+
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
272272
external_ids = read_vector<ids_type>(ctx, external_ids_uri, start_pos, end_pos, temporal_policy);
273273
}
274274
return ivf_index<T, ids_type, centroids_type>(

src/include/detail/ivf/qv.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ auto nuv_query_heap_infinite_ram_reg_blocked(
422422
size_t nthreads,
423423
uint64_t timestamp = 0) {
424424
scoped_timer _{tdb_func__};
425-
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp+1);
425+
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
426426

427427
// Read the shuffled database and ids
428428
// @todo Do this more systematically
@@ -647,7 +647,7 @@ auto qv_query_heap_finite_ram(
647647
size_t nthreads,
648648
uint64_t timestamp = 0) {
649649
tiledb::Context ctx;
650-
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp+1);
650+
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
651651

652652
auto centroids = tdbColMajorMatrix<centroids_type>(ctx, centroids_uri, 0, temporal_policy);
653653
centroids.load();
@@ -722,7 +722,7 @@ auto qv_query_heap_finite_ram(
722722
size_t nthreads,
723723
uint64_t timestamp) {
724724
scoped_timer _{tdb_func__};
725-
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp+1);
725+
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
726726

727727
using score_type = float;
728728
using indices_type =
@@ -1194,7 +1194,7 @@ auto nuv_query_heap_finite_ram_reg_blocked(
11941194
size_t nthreads,
11951195
uint64_t timestamp = 0) {
11961196
scoped_timer _{tdb_func__ + " " + part_uri};
1197-
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp+1);
1197+
auto temporal_policy = (timestamp == 0) ? tiledb::TemporalPolicy() : tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
11981198

11991199
// Check that the size of the indices vector is correct
12001200
assert(size(indices) == centroids.num_cols() + 1);

0 commit comments

Comments
 (0)