55from tiledb .vector_search import flat_index , ivf_flat_index
66from tiledb .vector_search .index import Index
77
def query_and_check(index, queries, k, expected, **kwargs):
    """Query ``index`` repeatedly and assert the expected ids are returned.

    Calls ``index.query(queries, k=k, **kwargs)`` three times and, after each
    call, checks that every id in ``expected`` appears in the first row of
    the returned ids. Only the first query row is inspected.

    Args:
        index: Object exposing ``query(queries, k=..., **kwargs)`` that
            returns a ``(distances, ids)`` pair.
        queries: Query vectors, forwarded unchanged to ``index.query``.
        k: Number of results requested, forwarded as ``k=``.
        expected: Iterable of ids that must all be present in ``ids[0]``.
        **kwargs: Extra keyword arguments forwarded to ``index.query``
            (for example ``nprobe``).

    Raises:
        AssertionError: If any expected id is missing from the first row of
            ids on any of the three runs.
    """
    expected_ids = set(expected)
    # NOTE(review): the query is issued three times on the same index object;
    # presumably this guards against results changing between repeated
    # calls -- confirm the intent with the author.
    for attempt in range(3):
        # Distances are returned too, but only the ids are checked here.
        _, result_i = index.query(queries, k=k, **kwargs)
        missing = expected_ids - set(result_i[0])
        assert not missing, (
            f"query attempt {attempt + 1}: expected ids {sorted(missing)} "
            f"missing from results {list(result_i[0])}"
        )
812
913def test_flat_index (tmp_path ):
1014 uri = os .path .join (tmp_path , "array" )
1115 index = flat_index .create (uri = uri , dimensions = 3 , vector_type = np .dtype (np .uint8 ))
12- result_d , result_i = index .query (np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 )
13- assert {ind .MAX_UINT64 } == set (result_i [0 ])
16+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {ind .MAX_UINT64 })
1417
1518 update_vectors = np .empty ([5 ], dtype = object )
1619 update_vectors [0 ] = np .array ([0 , 0 , 0 ], dtype = np .dtype (np .uint8 ))
@@ -19,39 +22,31 @@ def test_flat_index(tmp_path):
1922 update_vectors [3 ] = np .array ([3 , 3 , 3 ], dtype = np .dtype (np .uint8 ))
2023 update_vectors [4 ] = np .array ([4 , 4 , 4 ], dtype = np .dtype (np .uint8 ))
2124 index .update_batch (vectors = update_vectors , external_ids = np .array ([0 , 1 , 2 , 3 , 4 ]))
22- result_d , result_i = index .query (np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 )
23- assert {1 , 2 , 3 }.issubset (set (result_i [0 ]))
25+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {1 , 2 , 3 })
2426
2527 index = index .consolidate_updates ()
26- result_d , result_i = index .query (np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 )
27- assert {1 , 2 , 3 }.issubset (set (result_i [0 ]))
28+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {1 , 2 , 3 })
2829
2930 index .delete_batch (external_ids = np .array ([1 , 3 ]))
30- result_d , result_i = index .query (np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 )
31- assert {0 , 2 , 4 }.issubset (set (result_i [0 ]))
31+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {0 , 2 , 4 })
3232
3333 index = index .consolidate_updates ()
34- result_d , result_i = index .query (np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 )
35- assert {0 , 2 , 4 }.issubset (set (result_i [0 ]))
34+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {0 , 2 , 4 })
3635
3736 update_vectors = np .empty ([2 ], dtype = object )
3837 update_vectors [0 ] = np .array ([1 , 1 , 1 ], dtype = np .dtype (np .uint8 ))
3938 update_vectors [1 ] = np .array ([3 , 3 , 3 ], dtype = np .dtype (np .uint8 ))
4039 index .update_batch (vectors = update_vectors , external_ids = np .array ([1 , 3 ]))
41- result_d , result_i = index .query (np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 )
42- assert {1 , 2 , 3 }.issubset (set (result_i [0 ]))
40+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {1 , 2 , 3 })
4341
4442 index = index .consolidate_updates ()
45- result_d , result_i = index .query (np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 )
46- assert {1 , 2 , 3 }.issubset (set (result_i [0 ]))
43+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {1 , 2 , 3 })
4744
4845 index .delete_batch (external_ids = np .array ([1 , 3 ]))
49- result_d , result_i = index .query (np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 )
50- assert {0 , 2 , 4 }.issubset (set (result_i [0 ]))
46+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {0 , 2 , 4 })
5147
5248 index = index .consolidate_updates ()
53- result_d , result_i = index .query (np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 )
54- assert {0 , 2 , 4 }.issubset (set (result_i [0 ]))
49+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {0 , 2 , 4 })
5550
5651
5752def test_ivf_flat_index (tmp_path ):
@@ -60,10 +55,7 @@ def test_ivf_flat_index(tmp_path):
6055 index = ivf_flat_index .create (
6156 uri = uri , dimensions = 3 , vector_type = np .dtype (np .uint8 ), partitions = partitions
6257 )
63- result_d , result_i = index .query (
64- np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 , nprobe = partitions
65- )
66- assert {ind .MAX_UINT64 } == set (result_i [0 ])
58+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {ind .MAX_UINT64 }, nprobe = partitions )
6759
6860 update_vectors = np .empty ([5 ], dtype = object )
6961 update_vectors [0 ] = np .array ([0 , 0 , 0 ], dtype = np .dtype (np .uint8 ))
@@ -72,52 +64,28 @@ def test_ivf_flat_index(tmp_path):
7264 update_vectors [3 ] = np .array ([3 , 3 , 3 ], dtype = np .dtype (np .uint8 ))
7365 update_vectors [4 ] = np .array ([4 , 4 , 4 ], dtype = np .dtype (np .uint8 ))
7466 index .update_batch (vectors = update_vectors , external_ids = np .array ([0 , 1 , 2 , 3 , 4 ]))
75- result_d , result_i = index .query (
76- np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 , nprobe = partitions
77- )
78- assert {1 , 2 , 3 }.issubset (set (result_i [0 ]))
67+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {1 , 2 , 3 }, nprobe = partitions )
7968
8069 index = index .consolidate_updates ()
81- result_d , result_i = index .query (
82- np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 , nprobe = partitions
83- )
84- assert {1 , 2 , 3 }.issubset (set (result_i [0 ]))
70+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {1 , 2 , 3 }, nprobe = partitions )
8571
8672 index .delete_batch (external_ids = np .array ([1 , 3 ]))
87- result_d , result_i = index .query (
88- np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 , nprobe = partitions
89- )
90- assert {0 , 2 , 4 }.issubset (set (result_i [0 ]))
73+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {0 , 2 , 4 }, nprobe = partitions )
9174
9275 index = index .consolidate_updates ()
93- result_d , result_i = index .query (
94- np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 , nprobe = partitions
95- )
96- assert {0 , 2 , 4 }.issubset (set (result_i [0 ]))
76+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {0 , 2 , 4 }, nprobe = partitions )
9777
9878 update_vectors = np .empty ([2 ], dtype = object )
9979 update_vectors [0 ] = np .array ([1 , 1 , 1 ], dtype = np .dtype (np .uint8 ))
10080 update_vectors [1 ] = np .array ([3 , 3 , 3 ], dtype = np .dtype (np .uint8 ))
10181 index .update_batch (vectors = update_vectors , external_ids = np .array ([1 , 3 ]))
102- result_d , result_i = index .query (
103- np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 , nprobe = partitions
104- )
105- assert {1 , 2 , 3 }.issubset (set (result_i [0 ]))
82+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {1 , 2 , 3 }, nprobe = partitions )
10683
10784 index = index .consolidate_updates ()
108- result_d , result_i = index .query (
109- np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 , nprobe = partitions
110- )
111- assert {1 , 2 , 3 }.issubset (set (result_i [0 ]))
85+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {1 , 2 , 3 }, nprobe = partitions )
11286
11387 index .delete_batch (external_ids = np .array ([1 , 3 ]))
114- result_d , result_i = index .query (
115- np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 , nprobe = partitions
116- )
117- assert {0 , 2 , 4 }.issubset (set (result_i [0 ]))
88+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {0 , 2 , 4 }, nprobe = partitions )
11889
11990 index = index .consolidate_updates ()
120- result_d , result_i = index .query (
121- np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), k = 3 , nprobe = partitions
122- )
123- assert {0 , 2 , 4 }.issubset (set (result_i [0 ]))
91+ query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {0 , 2 , 4 }, nprobe = partitions )
0 commit comments