@@ -365,7 +365,7 @@ def test_ivf_flat_ingestion_with_updates(tmp_path):
365365 _ , result = index .query (queries , k = k , nprobe = nprobe )
366366 assert accuracy (result , gt_i , updated_ids = updated_ids ) == 1.0
367367
368- index = index .consolidate_updates (partitions = 20 )
368+ index = index .consolidate_updates (retrain_index = True , partitions = 20 )
369369 _ , result = index .query (queries , k = k , nprobe = 20 )
370370 assert accuracy (result , gt_i , updated_ids = updated_ids ) == 1.0
371371
@@ -733,7 +733,7 @@ def test_storage_versions(tmp_path):
733733 _ , result = index .query (queries , k = k )
734734 assert accuracy (result , gt_i , updated_ids = updated_ids ) >= MINIMUM_ACCURACY
735735
736- index = index .consolidate_updates (partitions = 20 )
736+ index = index .consolidate_updates (retrain_index = True , partitions = 20 )
737737 _ , result = index .query (queries , k = k )
738738 assert accuracy (result , gt_i , updated_ids = updated_ids ) >= MINIMUM_ACCURACY
739739
@@ -776,7 +776,8 @@ def test_copy_centroids_uri(tmp_path):
776776 index_type = "IVF_FLAT" ,
777777 index_uri = index_uri ,
778778 input_vectors = data ,
779- copy_centroids_uri = centroids_uri
779+ copy_centroids_uri = centroids_uri ,
780+ partitions = centroids_in_size
780781 )
781782
782783 # Query the index.
@@ -931,8 +932,48 @@ def test_ingest_with_training_source_uri_tdb(tmp_path):
931932 queries = np .array ([data .transpose ()[1 ]], dtype = np .float32 )
932933 query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1 ]])
933934
935+ update_vectors = np .empty ([3 ], dtype = object )
936+ update_vectors [0 ] = np .array ([6.0 , 6.1 , 6.2 , 6.3 ], dtype = np .dtype (np .float32 ))
937+ update_vectors [1 ] = np .array ([7.0 , 7.1 , 7.2 , 7.3 ], dtype = np .dtype (np .float32 ))
938+ update_vectors [2 ] = np .array ([8.0 , 8.1 , 8.2 , 8.3 ], dtype = np .dtype (np .float32 ))
939+ index .update_batch (vectors = update_vectors , external_ids = np .array ([1000 , 1001 , 1002 ]))
940+
941+ index = index .consolidate_updates ()
942+
943+ queries = np .array ([update_vectors [2 ]], dtype = np .float32 )
944+ query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1002 ]])
945+
946+ ################################################################################################
947+ # Test we can load the index again and query, update, and consolidate.
948+ ################################################################################################
949+ # Load the index again and query.
934950 index = IVFFlatIndex (uri = index_uri )
935- query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1 ]])
951+
952+ query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1002 ]])
953+
954+ # Update the index and query.
955+ update_vectors = np .empty ([2 ], dtype = object )
956+ update_vectors [0 ] = np .array ([9.0 , 9.1 , 9.2 , 9.3 ], dtype = np .dtype (np .float32 ))
957+ update_vectors [1 ] = np .array ([10.0 , 10.1 , 10.2 , 10.3 ], dtype = np .dtype (np .float32 ))
958+ index .update_batch (vectors = update_vectors , external_ids = np .array ([1003 , 1004 ]))
959+ index = index .consolidate_updates ()
960+
961+ queries = np .array ([update_vectors [0 ]], dtype = np .float32 )
962+ query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1003 ]])
963+
964+ # Clear the index history, load, update, and query.
965+ Index .clear_history (uri = index_uri , timestamp = index .latest_ingestion_timestamp - 1 )
966+
967+ index = IVFFlatIndex (uri = index_uri )
968+
969+ update_vectors = np .empty ([2 ], dtype = object )
970+ update_vectors [0 ] = np .array ([11.0 , 11.1 , 11.2 , 11.3 ], dtype = np .dtype (np .float32 ))
971+ update_vectors [1 ] = np .array ([12.0 , 12.1 , 12.2 , 12.3 ], dtype = np .dtype (np .float32 ))
972+ index .update_batch (vectors = update_vectors , external_ids = np .array ([1003 , 1004 ]))
973+ index = index .consolidate_updates ()
974+
975+ queries = np .array ([update_vectors [0 ]], dtype = np .float32 )
976+ query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1003 ]])
936977
937978 ###############################################################################################
938979 # Also test that we can ingest with training_source_type.
@@ -984,5 +1025,39 @@ def test_ingest_with_training_source_uri_numpy(tmp_path):
9841025 queries = np .array ([data [1 ]], dtype = np .float32 )
9851026 query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1 ]])
9861027
987- index = IVFFlatIndex (uri = index_uri )
988- query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1 ]])
1028+ update_vectors = np .empty ([3 ], dtype = object )
1029+ update_vectors [0 ] = np .array ([6.0 , 6.1 , 6.2 , 6.3 ], dtype = np .dtype (np .float32 ))
1030+ update_vectors [1 ] = np .array ([7.0 , 7.1 , 7.2 , 7.3 ], dtype = np .dtype (np .float32 ))
1031+ update_vectors [2 ] = np .array ([8.0 , 8.1 , 8.2 , 8.3 ], dtype = np .dtype (np .float32 ))
1032+ index .update_batch (vectors = update_vectors , external_ids = np .array ([1000 , 1001 , 1002 ]))
1033+
1034+ index = index .consolidate_updates ()
1035+
1036+ queries = np .array ([update_vectors [2 ]], dtype = np .float32 )
1037+ query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1002 ]])
1038+
1039+ ################################################################################################
1040+ # Test we can load the index again and query, update, and consolidate.
1041+ ################################################################################################
1042+ index_ram = IVFFlatIndex (uri = index_uri )
1043+
1044+ queries = np .array ([data [1 ]], dtype = np .float32 )
1045+ query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1 ]])
1046+
1047+ update_vectors = np .empty ([2 ], dtype = object )
1048+ update_vectors [0 ] = np .array ([9.0 , 9.1 , 9.2 , 9.3 ], dtype = np .dtype (np .float32 ))
1049+ update_vectors [1 ] = np .array ([10.0 , 10.1 , 10.2 , 10.3 ], dtype = np .dtype (np .float32 ))
1050+ index .update_batch (vectors = update_vectors , external_ids = np .array ([1003 , 1004 ]))
1051+ index_ram = index_ram .consolidate_updates ()
1052+
1053+ queries = np .array ([update_vectors [0 ]], dtype = np .float32 )
1054+ query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1003 ]])
1055+
1056+ update_vectors = np .empty ([2 ], dtype = object )
1057+ update_vectors [0 ] = np .array ([11.0 , 11.1 , 11.2 , 11.3 ], dtype = np .dtype (np .float32 ))
1058+ update_vectors [1 ] = np .array ([12.0 , 12.1 , 12.2 , 12.3 ], dtype = np .dtype (np .float32 ))
1059+ index .update_batch (vectors = update_vectors , external_ids = np .array ([1003 , 1004 ]))
1060+ index_ram = index_ram .consolidate_updates (retrain_index = True , training_sample_size = 3 )
1061+
1062+ queries = np .array ([update_vectors [0 ]], dtype = np .float32 )
1063+ query_and_check_equals (index = index , queries = queries , expected_result_d = [[0 ]], expected_result_i = [[1003 ]])
0 commit comments