@@ -307,13 +307,50 @@ def test_ivf_flat_ingestion_with_updates(tmp_path):
307307 _ , result = index .query (query_vectors , k = k , nprobe = nprobe )
308308 assert accuracy (result , gt_i ) > MINIMUM_ACCURACY
309309
310+ updated_ids = {}
311+ for i in range (100 ):
312+ index .delete (external_id = i )
313+ index .update (vector = data [i ].astype (dtype ), external_id = i + 1000000 )
314+ updated_ids [i + 1000000 ] = i
315+
316+ _ , result = index .query (query_vectors , k = k , nprobe = nprobe )
317+ assert accuracy (result , gt_i , updated_ids = updated_ids ) > MINIMUM_ACCURACY
318+
319+ index = index .consolidate_updates ()
320+ _ , result = index .query (query_vectors , k = k , nprobe = nprobe )
321+ assert accuracy (result , gt_i , updated_ids = updated_ids ) > MINIMUM_ACCURACY
322+
323+ def test_ivf_flat_ingestion_with_batch_updates (tmp_path ):
324+ dataset_dir = os .path .join (tmp_path , "dataset" )
325+ index_uri = os .path .join (tmp_path , "array" )
326+ k = 10
327+ size = 100000
328+ partitions = 100
329+ dimensions = 128
330+ nqueries = 100
331+ nprobe = 20
332+ data = create_random_dataset_u8 (nb = size , d = dimensions , nq = nqueries , k = k , path = dataset_dir )
333+ dtype = np .uint8
334+
335+ query_vectors = get_queries (dataset_dir , dtype = dtype )
336+ gt_i , gt_d = get_groundtruth (dataset_dir , k )
337+ index = ingest (
338+ index_type = "IVF_FLAT" ,
339+ index_uri = index_uri ,
340+ source_uri = os .path .join (dataset_dir , "data.u8bin" ),
341+ partitions = partitions ,
342+ input_vectors_per_work_item = int (size / 10 ),
343+ )
344+ _ , result = index .query (query_vectors , k = k , nprobe = nprobe )
345+ assert accuracy (result , gt_i ) > MINIMUM_ACCURACY
346+
310347 update_ids = {}
311348 updated_ids = {}
312349 for i in range (0 , 100000 , 2 ):
313350 update_ids [i ] = i + 1000000
314351 updated_ids [i + 1000000 ] = i
315- external_ids = np .zeros ((len (update_ids )* 2 ), dtype = np .uint64 )
316- updates = np .empty ((len (update_ids )* 2 ), dtype = 'O' )
352+ external_ids = np .zeros ((len (update_ids ) * 2 ), dtype = np .uint64 )
353+ updates = np .empty ((len (update_ids ) * 2 ), dtype = 'O' )
317354 id = 0
318355 for prev_id , new_id in update_ids .items ():
319356 external_ids [id ] = prev_id
0 commit comments