Skip to content

Commit 23f5a51

Browse files
Author: Nikos Papailiou
Commit message: Add update test
1 parent: 6bb5d2b · commit: 23f5a51

File tree

2 files changed: 45 additions (+45) and 4 deletions (-4).

apis/python/src/tiledb/vector_search/index.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -103,7 +103,9 @@ def query_additions(self, queries: np.ndarray, k):
103103

104104
def update(self, vector: np.array, external_id: np.uint64):
105105
updates_array = self.open_updates_array()
106-
updates_array[external_id] = vector
106+
vectors = np.empty((1), dtype='O')
107+
vectors[0] = vector
108+
updates_array[external_id] = {'vector': vectors}
107109
updates_array.close()
108110
self.consolidate_update_fragments()
109111

@@ -115,7 +117,9 @@ def update_batch(self, vectors: np.ndarray, external_ids: np.array):
115117

116118
def delete(self, external_id: np.uint64):
117119
updates_array = self.open_updates_array()
118-
updates_array[external_id] = np.array([], dtype=self.dtype)
120+
deletes = np.empty((1), dtype='O')
121+
deletes[0] = np.array([], dtype=self.dtype)
122+
updates_array[external_id] = {'vector': deletes}
119123
updates_array.close()
120124
self.consolidate_update_fragments()
121125

apis/python/test/test_ingestion.py

Lines changed: 39 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -307,13 +307,50 @@ def test_ivf_flat_ingestion_with_updates(tmp_path):
307307
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
308308
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
309309

310+
updated_ids = {}
311+
for i in range(100):
312+
index.delete(external_id=i)
313+
index.update(vector=data[i].astype(dtype), external_id=i + 1000000)
314+
updated_ids[i + 1000000] = i
315+
316+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
317+
assert accuracy(result, gt_i, updated_ids=updated_ids) > MINIMUM_ACCURACY
318+
319+
index = index.consolidate_updates()
320+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
321+
assert accuracy(result, gt_i, updated_ids=updated_ids) > MINIMUM_ACCURACY
322+
323+
def test_ivf_flat_ingestion_with_batch_updates(tmp_path):
324+
dataset_dir = os.path.join(tmp_path, "dataset")
325+
index_uri = os.path.join(tmp_path, "array")
326+
k = 10
327+
size = 100000
328+
partitions = 100
329+
dimensions = 128
330+
nqueries = 100
331+
nprobe = 20
332+
data = create_random_dataset_u8(nb=size, d=dimensions, nq=nqueries, k=k, path=dataset_dir)
333+
dtype = np.uint8
334+
335+
query_vectors = get_queries(dataset_dir, dtype=dtype)
336+
gt_i, gt_d = get_groundtruth(dataset_dir, k)
337+
index = ingest(
338+
index_type="IVF_FLAT",
339+
index_uri=index_uri,
340+
source_uri=os.path.join(dataset_dir, "data.u8bin"),
341+
partitions=partitions,
342+
input_vectors_per_work_item=int(size / 10),
343+
)
344+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
345+
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
346+
310347
update_ids = {}
311348
updated_ids = {}
312349
for i in range(0, 100000, 2):
313350
update_ids[i] = i + 1000000
314351
updated_ids[i + 1000000] = i
315-
external_ids = np.zeros((len(update_ids)*2), dtype=np.uint64)
316-
updates = np.empty((len(update_ids)*2), dtype='O')
352+
external_ids = np.zeros((len(update_ids) * 2), dtype=np.uint64)
353+
updates = np.empty((len(update_ids) * 2), dtype='O')
317354
id = 0
318355
for prev_id, new_id in update_ids.items():
319356
external_ids[id] = prev_id

0 commit comments

Comments
 (0)