1919from tiledb .vector_search .vamana_index import VamanaIndex
2020
2121
def query_and_check_distances(
    index, queries, k, expected_distances, expected_ids, **kwargs
):
    """Query ``index`` once and assert the exact ids and distances it returns.

    Args:
        index: Object exposing ``query(queries, k=..., **kwargs)`` that returns
            a ``(distances, ids)`` pair.
        queries: Query vectors forwarded unchanged to ``index.query``.
        k: Number of results requested per query.
        expected_distances: Array-like compared against the returned distances.
        expected_ids: Array-like compared against the returned ids.
        **kwargs: Extra keyword arguments forwarded to ``index.query``.

    Raises:
        AssertionError: If the returned ids or distances differ from the
            expected values.
    """
    distances, ids = index.query(queries, k=k, **kwargs)
    # The comparison is exact (np.array_equal), not tolerance based. Ids are
    # checked first, so an id mismatch is reported before a distance mismatch.
    assert np.array_equal(
        ids, expected_ids
    ), f"ids mismatch: expected {expected_ids}, got {ids}"
    assert np.array_equal(
        distances, expected_distances
    ), f"distances mismatch: expected {expected_distances}, got {distances}"
29+
30+
2231def query_and_check (index , queries , k , expected , ** kwargs ):
2332 for _ in range (3 ):
2433 result_d , result_i = index .query (queries , k = k , ** kwargs )
@@ -167,7 +176,7 @@ def test_ivf_flat_index(tmp_path):
167176 )
168177
169178
170- def test_vamana_index (tmp_path ):
179+ def test_vamana_index_simple (tmp_path ):
171180 uri = os .path .join (tmp_path , "array" )
172181 dimensions = 3
173182 vector_type = np .dtype (np .uint8 )
@@ -188,14 +197,68 @@ def test_vamana_index(tmp_path):
188197 query_and_check (index , np .array ([[2 , 2 , 2 ]], dtype = np .float32 ), 3 , {ind .MAX_UINT64 })
189198
190199
def test_vamana_index(tmp_path):
    """Check Vamana query results on an empty index, after an update batch,
    and after consolidating the updates.
    """
    # Local import: the cleanup below is the only user of shutil.
    import shutil

    uri = os.path.join(tmp_path, "array")
    if os.path.exists(uri):
        # os.rmdir only removes empty directories and raises OSError otherwise;
        # rmtree also clears a leftover directory that still has content.
        shutil.rmtree(uri)
    vector_type = np.float32

    index = vamana_index.create(
        uri=uri,
        dimensions=3,
        vector_type=np.dtype(vector_type),
        id_type=np.dtype(np.uint32),
    )

    # Nothing has been added yet: the only result slot is expected to hold
    # ind.MAX_FLOAT_32 as the distance and ind.MAX_UINT64 as the id.
    queries = np.array([[2, 2, 2]], dtype=np.float32)
    distances, ids = index.query(queries, k=1)
    assert distances.shape == (1, 1)
    assert ids.shape == (1, 1)
    assert distances[0][0] == ind.MAX_FLOAT_32
    assert ids[0][0] == ind.MAX_UINT64
    query_and_check_distances(
        index, queries, 1, [[ind.MAX_FLOAT_32]], [[ind.MAX_UINT64]]
    )

    # Add five vectors [i, i, i] whose external id equals i.
    num_vectors = 5
    update_vectors = np.empty([num_vectors], dtype=object)
    for i in range(num_vectors):
        update_vectors[i] = np.array([i, i, i], dtype=np.dtype(np.float32))
    index.update_batch(
        vectors=update_vectors,
        external_ids=np.array([0, 1, 2, 3, 4], dtype=np.dtype(np.uint32)),
    )
    query_and_check_distances(
        index, np.array([[2, 2, 2]], dtype=np.float32), 2, [[0, 3]], [[2, 1]]
    )

    index = index.consolidate_updates()

    # TODO(paris): Does not work with k > 1 or with [0, 0, 0] as the query.
    # Each stored vector (except [0, 0, 0], see TODO) must match itself at
    # distance 0.
    for i in range(1, num_vectors):
        query_and_check_distances(
            index, np.array([[i, i, i]], dtype=np.float32), 1, [[0]], [[i]]
        )
252+
253+
def test_delete_invalid_index(tmp_path):
    """Call ``Index.delete_index`` on a URI that names no index.

    The test passes as long as the call returns without raising; nothing is
    asserted about its result.
    """
    # We don't throw with an invalid uri.
    Index.delete_index(uri="invalid_uri", config=tiledb.cloud.Config())
194257
195258
196259def test_delete_index (tmp_path ):
197- indexes = ["FLAT" , "IVF_FLAT" ]
198- index_classes = [FlatIndex , IVFFlatIndex ]
260+ indexes = ["FLAT" , "IVF_FLAT" , "VAMANA" ]
261+ index_classes = [FlatIndex , IVFFlatIndex , VamanaIndex ]
199262 data = np .array ([[1.0 , 1.1 , 1.2 , 1.3 ], [2.0 , 2.1 , 2.2 , 2.3 ]], dtype = np .float32 )
200263 for index_type , index_class in zip (indexes , index_classes ):
201264 index_uri = os .path .join (tmp_path , f"array_{ index_type } " )
@@ -229,7 +292,7 @@ def test_index_with_incorrect_dimensions(tmp_path):
229292def test_index_with_incorrect_num_of_query_columns_simple (tmp_path ):
230293 siftsmall_uri = siftsmall_inputs_file
231294 queries_uri = siftsmall_query_file
232- indexes = ["FLAT" , "IVF_FLAT" ]
295+ indexes = ["FLAT" , "IVF_FLAT" , "VAMANA" ]
233296 for index_type in indexes :
234297 index_uri = os .path .join (tmp_path , f"sift10k_flat_{ index_type } " )
235298 index = ingest (
@@ -253,7 +316,7 @@ def test_index_with_incorrect_num_of_query_columns_complex(tmp_path):
253316 # Tests that we raise a TypeError if the number of columns in the query is not the same as the
254317 # number of columns in the indexed data.
255318 size = 1000
256- indexes = ["FLAT" , "IVF_FLAT" ]
319+ indexes = ["FLAT" , "IVF_FLAT" , "VAMANA" ]
257320 num_columns_in_vector = [1 , 2 , 3 , 4 , 5 , 10 ]
258321 for index_type in indexes :
259322 for num_columns in num_columns_in_vector :
@@ -298,7 +361,7 @@ def test_index_with_incorrect_num_of_query_columns_in_single_vector_query(tmp_pa
298361 # Tests that we raise a TypeError if the number of columns in the query is not the same as the
299362 # number of columns in the indexed data, specifically for a single vector query.
300363 # i.e. queries = [1, 2, 3] instead of queries = [[1, 2, 3], [4, 5, 6]].
301- indexes = [flat_index , ivf_flat_index ]
364+ indexes = [flat_index , ivf_flat_index , vamana_index ]
302365 for index_type in indexes :
303366 uri = os .path .join (tmp_path , f"array_{ index_type .__name__ } " )
304367 index = index_type .create (uri = uri , dimensions = 3 , vector_type = np .dtype (np .uint8 ))
0 commit comments