Skip to content

Commit abe48fe

Browse files
Merge pull request #170 from TileDB-Inc/npapa/fix-update-uri-bug
Fix bug for creating the updates array using a tiledb URI
2 parents f7b8c0b + 7bb5c08 commit abe48fe

File tree

4 files changed

+80
-10
lines changed

4 files changed

+80
-10
lines changed

apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
validate_storage_version)
1111

1212
MAX_INT32 = np.iinfo(np.dtype("int32")).max
13+
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
1314
TILE_SIZE_BYTES = 128000000 # 128MB
1415
INDEX_TYPE = "FLAT"
1516

@@ -139,8 +140,10 @@ def create(
139140
tile_size = TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions
140141
ids_array_name = storage_formats[storage_version]["IDS_ARRAY_NAME"]
141142
parts_array_name = storage_formats[storage_version]["PARTS_ARRAY_NAME"]
143+
updates_array_name = storage_formats[storage_version]["UPDATES_ARRAY_NAME"]
142144
ids_uri = f"{uri}/{ids_array_name}"
143145
parts_uri = f"{uri}/{parts_array_name}"
146+
updates_array_uri = f"{uri}/{updates_array_name}"
144147

145148
ids_array_rows_dim = tiledb.Dim(
146149
name="rows",
@@ -192,5 +195,21 @@ def create(
192195
tiledb.Array.create(parts_uri, parts_schema)
193196
group.add(parts_uri, name=parts_array_name)
194197

198+
external_id_dim = tiledb.Dim(
199+
name="external_id",
200+
domain=(0, MAX_UINT64 - 1),
201+
dtype=np.dtype(np.uint64),
202+
)
203+
dom = tiledb.Domain(external_id_dim)
204+
vector_attr = tiledb.Attr(name="vector", dtype=vector_type, var=True)
205+
updates_schema = tiledb.ArraySchema(
206+
domain=dom,
207+
sparse=True,
208+
attrs=[vector_attr],
209+
allows_duplicates=False,
210+
)
211+
tiledb.Array.create(updates_array_uri, updates_schema)
212+
group.add(updates_array_uri, name=updates_array_name)
213+
195214
group.close()
196215
return FlatIndex(uri=uri, config=config)

apis/python/src/tiledb/vector_search/index.py

Lines changed: 24 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -52,9 +52,11 @@ def __init__(
5252
raise ValueError(
5353
f"Time traveling is not supported for index storage_version={self.storage_version}"
5454
)
55-
5655
updates_array_name = storage_formats[self.storage_version]["UPDATES_ARRAY_NAME"]
57-
self.updates_array_uri = f"{self.group.uri}/{updates_array_name}"
56+
if updates_array_name in self.group:
57+
self.updates_array_uri = self.group[storage_formats[self.storage_version]["UPDATES_ARRAY_NAME"]].uri
58+
else:
59+
self.updates_array_uri = f"{self.group.uri}/{updates_array_name}"
5860
self.index_version = self.group.meta.get("index_version", "")
5961
self.ingestion_timestamps = [
6062
int(x)
@@ -134,7 +136,7 @@ def query(self, queries: np.ndarray, k, **kwargs):
134136
raise TypeError(f"A query in queries has {query_dimensions} dimensions, but the indexed data had {self.dimensions} dimensions")
135137

136138
with tiledb.scope_ctx(ctx_or_config=self.config):
137-
if not tiledb.array_exists(self.updates_array_uri):
139+
if not tiledb.array_exists(self.updates_array_uri) or not self.has_updates():
138140
if self.query_base_array:
139141
return self.query_internal(queries, k, **kwargs)
140142
else:
@@ -266,7 +268,22 @@ def get_dimensions(self):
266268
def query_internal(self, queries: np.ndarray, k, **kwargs):
267269
raise NotImplementedError
268270

271+
def has_updates(self):
    """Return the ``has_updates`` flag stored in the index group metadata.

    Returns:
        The stored ``has_updates`` value, or ``True`` when the group
        metadata has no ``has_updates`` key (a group without the flag is
        treated as one that may hold updates).
    """
    # One lookup with a default instead of an ``in`` test followed by a
    # second read of the same key.
    return self.group.meta.get("has_updates", True)
276+
277+
def set_has_updates(self, has_updates: bool = True):
    """Persist the ``has_updates`` flag in the index group metadata.

    The group is only reopened for writing when the stored value differs
    from the requested one. After the call ``self.group`` is a freshly
    opened read-mode handle.

    Args:
        has_updates: Value to store under the ``has_updates`` metadata key.
    """
    # ``.get`` instead of ``[...]``: a group whose metadata lacks the key
    # must not raise KeyError here. A missing key yields ``None``, which
    # never equals the requested bool, so the flag is then written.
    if self.group.meta.get("has_updates") == has_updates:
        return

    # The metadata write needs a write-mode handle, so the read handle is
    # closed first and restored afterwards.
    self.group.close()
    try:
        write_group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
        try:
            write_group.meta["has_updates"] = has_updates
        finally:
            write_group.close()
    finally:
        # Always hand back a read-mode handle, even if the write failed,
        # so the index object is not left holding a closed group.
        self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(self.config))
284+
269285
def update(self, vector: np.array, external_id: np.uint64, timestamp: int = None):
286+
self.set_has_updates()
270287
updates_array = self.open_updates_array(timestamp=timestamp)
271288
vectors = np.empty((1), dtype="O")
272289
vectors[0] = vector
@@ -277,12 +294,14 @@ def update(self, vector: np.array, external_id: np.uint64, timestamp: int = None
277294
def update_batch(
    self, vectors: np.ndarray, external_ids: np.array, timestamp: int = None
):
    """Write a batch of vectors into the updates array.

    Args:
        vectors: Values stored under the ``vector`` attribute.
        external_ids: Coordinates in the updates array that the vectors
            are written to.
        timestamp: Forwarded to ``open_updates_array``.
    """
    self.set_has_updates()
    updates_array = self.open_updates_array(timestamp=timestamp)
    try:
        updates_array[external_ids] = {"vector": vectors}
    finally:
        # Release the array handle even when the write raises.
        updates_array.close()
    self.consolidate_update_fragments()
284302

285303
def delete(self, external_id: np.uint64, timestamp: int = None):
304+
self.set_has_updates()
286305
updates_array = self.open_updates_array(timestamp=timestamp)
287306
deletes = np.empty((1), dtype="O")
288307
deletes[0] = np.array([], dtype=self.dtype)
@@ -291,6 +310,7 @@ def delete(self, external_id: np.uint64, timestamp: int = None):
291310
self.consolidate_update_fragments()
292311

293312
def delete_batch(self, external_ids: np.array, timestamp: int = None):
313+
self.set_has_updates()
294314
updates_array = self.open_updates_array(timestamp=timestamp)
295315
deletes = np.empty((len(external_ids)), dtype="O")
296316
for i in range(len(external_ids)):
@@ -529,4 +549,5 @@ def create_metadata(
529549
group.meta["index_type"] = index_type
530550
group.meta["base_sizes"] = json.dumps([0])
531551
group.meta["ingestion_timestamps"] = json.dumps([0])
552+
group.meta["has_updates"] = False
532553
group.close()

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
validate_storage_version)
1313

1414
MAX_INT32 = np.iinfo(np.dtype("int32")).max
15+
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
1516
TILE_SIZE_BYTES = 64000000 # 64MB
1617
INDEX_TYPE = "IVF_FLAT"
1718

@@ -436,12 +437,8 @@ def dist_qv_udf(
436437
tmp = sorted(tmp_results, key=lambda t: t[0])[0:k]
437438
for j in range(len(tmp), k):
438439
tmp.append((float(0.0), int(0)))
439-
results_per_query_d.append(
440-
np.array(tmp, dtype=np.dtype("float,uint64"))["f0"]
441-
)
442-
results_per_query_i.append(
443-
np.array(tmp, dtype=np.dtype("float,uint64"))["f1"]
444-
)
440+
results_per_query_d.append(np.array(tmp, dtype=np.float32)[:, 0])
441+
results_per_query_i.append(np.array(tmp, dtype=np.uint64)[:, 1])
445442
return np.array(results_per_query_d), np.array(results_per_query_i)
446443

447444

@@ -474,10 +471,12 @@ def create(
474471
index_array_name = storage_formats[storage_version]["INDEX_ARRAY_NAME"]
475472
ids_array_name = storage_formats[storage_version]["IDS_ARRAY_NAME"]
476473
parts_array_name = storage_formats[storage_version]["PARTS_ARRAY_NAME"]
474+
updates_array_name = storage_formats[storage_version]["UPDATES_ARRAY_NAME"]
477475
centroids_uri = f"{uri}/{centroids_array_name}"
478476
index_array_uri = f"{uri}/{index_array_name}"
479477
ids_uri = f"{uri}/{ids_array_name}"
480478
parts_uri = f"{uri}/{parts_array_name}"
479+
updates_array_uri = f"{uri}/{updates_array_name}"
481480

482481
centroids_array_rows_dim = tiledb.Dim(
483482
name="rows",
@@ -581,5 +580,21 @@ def create(
581580
tiledb.Array.create(parts_uri, parts_schema)
582581
group.add(parts_uri, name=parts_array_name)
583582

583+
external_id_dim = tiledb.Dim(
584+
name="external_id",
585+
domain=(0, MAX_UINT64 - 1),
586+
dtype=np.dtype(np.uint64),
587+
)
588+
dom = tiledb.Domain(external_id_dim)
589+
vector_attr = tiledb.Attr(name="vector", dtype=vector_type, var=True)
590+
updates_schema = tiledb.ArraySchema(
591+
domain=dom,
592+
sparse=True,
593+
attrs=[vector_attr],
594+
allows_duplicates=False,
595+
)
596+
tiledb.Array.create(updates_array_uri, updates_schema)
597+
group.add(updates_array_uri, name=updates_array_name)
598+
584599
group.close()
585600
return IVFFlatIndex(uri=uri, config=config, memory_budget=1000000)

apis/python/test/test_cloud.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,9 +50,17 @@ def test_cloud_flat(self):
5050
config=tiledb.cloud.Config().dict(),
5151
mode=Mode.BATCH,
5252
)
53+
tiledb_index_uri = groups.info(index_uri).tiledb_uri
54+
index = vs.flat_index.FlatIndex(uri=tiledb_index_uri)
55+
5356
_, result_i = index.query(queries, k=k)
5457
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
5558

59+
index.delete(external_id=42)
60+
_, result_i = index.query(queries, k=k)
61+
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
62+
63+
5664
def test_cloud_ivf_flat(self):
5765
source_uri = "tiledb://TileDB-Inc/sift_10k"
5866
queries_uri = "test/data/siftsmall/siftsmall_query.fvecs"
@@ -79,6 +87,9 @@ def test_cloud_ivf_flat(self):
7987
# mode=Mode.BATCH,
8088
)
8189

90+
tiledb_index_uri = groups.info(index_uri).tiledb_uri
91+
index = vs.ivf_flat_index.IVFFlatIndex(uri=tiledb_index_uri)
92+
8293
_, result_i = index.query(queries, k=k, nprobe=nprobe)
8394
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
8495

@@ -113,4 +124,8 @@ def test_cloud_ivf_flat(self):
113124
with self.assertRaises(TypeError):
114125
index.query(queries, k=k, nprobe=nprobe, mode=Mode.REALTIME, resource_class="large", resources=resources)
115126
with self.assertRaises(TypeError):
116-
index.query(queries, k=k, nprobe=nprobe, mode=Mode.BATCH, resource_class="large", resources=resources)
127+
index.query(queries, k=k, nprobe=nprobe, mode=Mode.BATCH, resource_class="large", resources=resources)
128+
129+
index.delete(external_id=42)
130+
_, result_i = index.query(queries, k=k, nprobe=nprobe)
131+
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY

0 commit comments

Comments
 (0)