 import pytest
 from array_paths import *
 from common import *
+import json
 
 import tiledb.vector_search.index as ind
 from tiledb.vector_search import Index
 from tiledb.vector_search import flat_index
 from tiledb.vector_search import ivf_flat_index
+from tiledb.vector_search.index import create_metadata
+from tiledb.vector_search.index import DATASET_TYPE
 from tiledb.vector_search.flat_index import FlatIndex
 from tiledb.vector_search.ingestion import ingest
 from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
@@ -18,18 +21,49 @@ def query_and_check(index, queries, k, expected, **kwargs):
     result_d, result_i = index.query(queries, k=k, **kwargs)
     assert expected.issubset(set(result_i[0]))
 
+def check_default_metadata(uri, expected_vector_type, expected_storage_version, expected_index_type):
+    group = tiledb.Group(uri, "r", ctx=tiledb.Ctx(None))
+    assert "dataset_type" in group.meta
+    assert group.meta["dataset_type"] == DATASET_TYPE
+    assert type(group.meta["dataset_type"]) == str
+
+    assert "dtype" in group.meta
+    assert group.meta["dtype"] == np.dtype(expected_vector_type).name
+    assert type(group.meta["dtype"]) == str
+
+    assert "storage_version" in group.meta
+    assert group.meta["storage_version"] == expected_storage_version
+    assert type(group.meta["storage_version"]) == str
+
+    assert "index_type" in group.meta
+    assert group.meta["index_type"] == expected_index_type
+    assert type(group.meta["index_type"]) == str
+
+    assert "base_sizes" in group.meta
+    assert group.meta["base_sizes"] == json.dumps([0])
+    assert type(group.meta["base_sizes"]) == str
+
+    assert "ingestion_timestamps" in group.meta
+    assert group.meta["ingestion_timestamps"] == json.dumps([0])
+    assert type(group.meta["ingestion_timestamps"]) == str
+
+    assert "has_updates" in group.meta
+    assert group.meta["has_updates"] == False
+    assert type(group.meta["has_updates"]) == np.int64
 
 def test_flat_index(tmp_path):
     uri = os.path.join(tmp_path, "array")
-    index = flat_index.create(uri=uri, dimensions=3, vector_type=np.dtype(np.uint8))
+    vector_type = np.dtype(np.uint8)
+    index = flat_index.create(uri=uri, dimensions=3, vector_type=vector_type)
     query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {ind.MAX_UINT64})
+    check_default_metadata(uri, vector_type, STORAGE_VERSION, "FLAT")
 
     update_vectors = np.empty([5], dtype=object)
-    update_vectors[0] = np.array([0, 0, 0], dtype=np.dtype(np.uint8))
-    update_vectors[1] = np.array([1, 1, 1], dtype=np.dtype(np.uint8))
-    update_vectors[2] = np.array([2, 2, 2], dtype=np.dtype(np.uint8))
-    update_vectors[3] = np.array([3, 3, 3], dtype=np.dtype(np.uint8))
-    update_vectors[4] = np.array([4, 4, 4], dtype=np.dtype(np.uint8))
+    update_vectors[0] = np.array([0, 0, 0], dtype=vector_type)
+    update_vectors[1] = np.array([1, 1, 1], dtype=vector_type)
+    update_vectors[2] = np.array([2, 2, 2], dtype=vector_type)
+    update_vectors[3] = np.array([3, 3, 3], dtype=vector_type)
+    update_vectors[4] = np.array([4, 4, 4], dtype=vector_type)
     index.update_batch(vectors=update_vectors, external_ids=np.array([0, 1, 2, 3, 4]))
     query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {1, 2, 3})
 
@@ -43,8 +77,8 @@ def test_flat_index(tmp_path):
     query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {0, 2, 4})
 
     update_vectors = np.empty([2], dtype=object)
-    update_vectors[0] = np.array([1, 1, 1], dtype=np.dtype(np.uint8))
-    update_vectors[1] = np.array([3, 3, 3], dtype=np.dtype(np.uint8))
+    update_vectors[0] = np.array([1, 1, 1], dtype=vector_type)
+    update_vectors[1] = np.array([3, 3, 3], dtype=vector_type)
     index.update_batch(vectors=update_vectors, external_ids=np.array([1, 3]))
     query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {1, 2, 3})
 
@@ -61,9 +95,9 @@ def test_flat_index(tmp_path):
 def test_ivf_flat_index(tmp_path):
     partitions = 10
     uri = os.path.join(tmp_path, "array")
-
+    vector_type = np.dtype(np.uint8)
     index = ivf_flat_index.create(
-        uri=uri, dimensions=3, vector_type=np.dtype(np.uint8), partitions=partitions
+        uri=uri, dimensions=3, vector_type=vector_type, partitions=partitions
     )
     query_and_check(
         index,
@@ -72,13 +106,14 @@ def test_ivf_flat_index(tmp_path):
         {ind.MAX_UINT64},
         nprobe=partitions,
     )
+    check_default_metadata(uri, vector_type, STORAGE_VERSION, "IVF_FLAT")
 
     update_vectors = np.empty([5], dtype=object)
-    update_vectors[0] = np.array([0, 0, 0], dtype=np.dtype(np.uint8))
-    update_vectors[1] = np.array([1, 1, 1], dtype=np.dtype(np.uint8))
-    update_vectors[2] = np.array([2, 2, 2], dtype=np.dtype(np.uint8))
-    update_vectors[3] = np.array([3, 3, 3], dtype=np.dtype(np.uint8))
-    update_vectors[4] = np.array([4, 4, 4], dtype=np.dtype(np.uint8))
+    update_vectors[0] = np.array([0, 0, 0], dtype=vector_type)
+    update_vectors[1] = np.array([1, 1, 1], dtype=vector_type)
+    update_vectors[2] = np.array([2, 2, 2], dtype=vector_type)
+    update_vectors[3] = np.array([3, 3, 3], dtype=vector_type)
+    update_vectors[4] = np.array([4, 4, 4], dtype=vector_type)
     index.update_batch(vectors=update_vectors, external_ids=np.array([0, 1, 2, 3, 4]))
 
     query_and_check(
@@ -102,8 +137,8 @@ def test_ivf_flat_index(tmp_path):
     )
 
     update_vectors = np.empty([2], dtype=object)
-    update_vectors[0] = np.array([1, 1, 1], dtype=np.dtype(np.uint8))
-    update_vectors[1] = np.array([3, 3, 3], dtype=np.dtype(np.uint8))
+    update_vectors[0] = np.array([1, 1, 1], dtype=vector_type)
+    update_vectors[1] = np.array([3, 3, 3], dtype=vector_type)
     index.update_batch(vectors=update_vectors, external_ids=np.array([1, 3]))
     query_and_check(
         index, np.array([[2, 2, 2]], dtype=np.float32), 3, {1, 2, 3}, nprobe=partitions
@@ -251,3 +286,17 @@ def test_index_with_incorrect_num_of_query_columns_in_single_vector_query(tmp_path):
     # TODO: This also throws a TypeError for incorrect dimension
     with pytest.raises(TypeError):
         index.query(np.array([1, 1, 1], dtype=np.float32), k=3)
+
+def test_create_metadata(tmp_path):
+    uri = os.path.join(tmp_path, "array")
+
+    # Create the metadata at the specified URI.
+    dimensions = 3
+    vector_type: np.dtype = np.dtype(np.uint8)
+    index_type: str = "IVF_FLAT"
+    storage_version: str = STORAGE_VERSION
+    group_exists: bool = False
+    create_metadata(uri, dimensions, vector_type, index_type, storage_version, group_exists)
+
+    # Check it contains the default metadata.
+    check_default_metadata(uri, vector_type, storage_version, index_type)