1+ import json
2+ from typing import Any , Mapping
3+
14import numpy as np
25
6+ from tiledb .vector_search import index
37from tiledb .vector_search .module import *
4- from tiledb .vector_search .storage_formats import storage_formats
5- from tiledb .vector_search .index import Index
6- from typing import Any , Mapping
8+ from tiledb .vector_search .storage_formats import (STORAGE_VERSION ,
9+ storage_formats )
710
11+ MAX_INT32 = np .iinfo (np .dtype ("int32" )).max
12+ TILE_SIZE_BYTES = 128000000 # 128MB
13+ INDEX_TYPE = "FLAT"
814
9- class FlatIndex (Index ):
15+
16+ class FlatIndex (index .Index ):
1017 """
1118 Open a flat index
1219
@@ -23,41 +30,50 @@ def __init__(
2330 uri : str ,
2431 config : Optional [Mapping [str , Any ]] = None ,
2532 timestamp = None ,
33+ ** kwargs ,
2634 ):
2735 super ().__init__ (uri = uri , config = config , timestamp = timestamp )
28- self .index_type = "FLAT"
36+ self .index_type = INDEX_TYPE
2937 self ._index = None
30- self .db_uri = self .group [storage_formats [self .storage_version ]["PARTS_ARRAY_NAME" ] + self .index_version ].uri
31- schema = tiledb .ArraySchema .load (
32- self .db_uri , ctx = tiledb .Ctx (self .config )
33- )
38+ self .db_uri = self .group [
39+ storage_formats [self .storage_version ]["PARTS_ARRAY_NAME" ]
40+ + self .index_version
41+ ].uri
42+ schema = tiledb .ArraySchema .load (self .db_uri , ctx = tiledb .Ctx (self .config ))
3443 if self .base_size == - 1 :
3544 self .size = schema .domain .dim (1 ).domain [1 ] + 1
3645 else :
3746 self .size = self .base_size
38- self ._db = load_as_matrix (
39- self .db_uri ,
40- ctx = self .ctx ,
41- config = config ,
42- size = self .size ,
43- timestamp = self .base_array_timestamp ,
44- )
45- # Check for existence of ids array. Previous versions were not using external_ids in the ingestion assuming
46- # that the external_ids were the position of the vector in the array.
47- if storage_formats [self .storage_version ]["IDS_ARRAY_NAME" ] + self .index_version in self .group :
47+
48+ self .dtype = np .dtype (self .group .meta .get ("dtype" , None ))
49+ if (
50+ storage_formats [self .storage_version ]["IDS_ARRAY_NAME" ] + self .index_version
51+ in self .group
52+ ):
4853 self .ids_uri = self .group [
49- storage_formats [self .storage_version ]["IDS_ARRAY_NAME" ] + self .index_version
54+ storage_formats [self .storage_version ]["IDS_ARRAY_NAME" ]
55+ + self .index_version
5056 ].uri
51- self ._ids = read_vector_u64 (self .ctx , self .ids_uri , 0 , self .size , self .base_array_timestamp )
5257 else :
5358 self .ids_uri = ""
54- self ._ids = StdVector_u64 (np .arange (self .size ).astype (np .uint64 ))
55-
56- dtype = self .group .meta .get ("dtype" , None )
57- if dtype is None :
58- self .dtype = self ._db .dtype
59- else :
60- self .dtype = np .dtype (dtype )
59+ if self .size > 0 :
60+ self ._db = load_as_matrix (
61+ self .db_uri ,
62+ ctx = self .ctx ,
63+ config = config ,
64+ size = self .size ,
65+ timestamp = self .base_array_timestamp ,
66+ )
67+ if self .dtype is None :
68+ self .dtype = self ._db .dtype
69+ # Check for existence of ids array. Previous versions were not using external_ids in the ingestion assuming
70+ # that the external_ids were the position of the vector in the array.
71+ if self .ids_uri == "" :
72+ self ._ids = StdVector_u64 (np .arange (self .size ).astype (np .uint64 ))
73+ else :
74+ self ._ids = read_vector_u64 (
75+ self .ctx , self .ids_uri , 0 , self .size , self .base_array_timestamp
76+ )
6177
6278 def query_internal (
6379 self ,
@@ -80,10 +96,92 @@ def query_internal(
8096 # TODO:
8197 # - typecheck queries
8298 # - add all the options and query strategies
99+ if self .size == 0 :
100+ return np .full ((queries .shape [0 ], k ), index .MAX_FLOAT_32 ), np .full (
101+ (queries .shape [0 ], k ), index .MAX_UINT64
102+ )
83103
84104 assert queries .dtype == np .float32
85105
86106 queries_m = array_to_matrix (np .transpose (queries ))
87107 d , i = query_vq_heap (self ._db , queries_m , self ._ids , k , nthreads )
88108
89109 return np .transpose (np .array (d )), np .transpose (np .array (i ))
110+
111+
def create(
    uri: str,
    dimensions: int,
    vector_type: np.dtype,
    group_exists: bool = False,
    config: Optional[Mapping[str, Any]] = None,
    **kwargs,
) -> FlatIndex:
    """
    Create an empty flat index at ``uri`` and open it.

    Writes the index metadata through ``index.create_metadata``, then creates
    two dense arrays under ``uri`` (an ids array and a vector "parts" array)
    and registers both in the index group.

    Parameters
    ----------
    uri: str
        URI of the index group.
    dimensions: int
        Number of dimensions of each vector.
    vector_type: np.dtype
        Datatype of the vector values.
    group_exists: bool
        Forwarded unchanged to ``index.create_metadata``.
    config: Optional[Mapping[str, Any]]
        TileDB configuration used for the scoped context and for opening the
        resulting index.
    **kwargs
        Accepted but not used by this function.

    Returns
    -------
    FlatIndex
        The newly created index, opened with ``uri`` and ``config``.
    """
    index.create_metadata(
        uri=uri,
        dimensions=dimensions,
        vector_type=vector_type,
        index_type=INDEX_TYPE,
        group_exists=group_exists,
        config=config,
    )
    with tiledb.scope_ctx(ctx_or_config=config):
        group = tiledb.Group(uri, "w")
        # Close the group even if array creation or registration fails, so a
        # write-mode handle is never left open.
        try:
            # The tile extent of an int32 dimension must be a whole number:
            # use floor division rather than true division (which yields a
            # float), and never go below one vector per tile.
            bytes_per_vector = np.dtype(vector_type).itemsize * dimensions
            tile_size = max(1, int(TILE_SIZE_BYTES // bytes_per_vector))

            storage_format = storage_formats[STORAGE_VERSION]
            ids_array_name = storage_format["IDS_ARRAY_NAME"]
            parts_array_name = storage_format["PARTS_ARRAY_NAME"]
            ids_uri = f"{uri}/{ids_array_name}"
            parts_uri = f"{uri}/{parts_array_name}"

            # Ids array: 1-D dense array of uint64 values.
            ids_array_rows_dim = tiledb.Dim(
                name="rows",
                domain=(0, MAX_INT32),
                tile=tile_size,
                dtype=np.dtype(np.int32),
            )
            ids_array_dom = tiledb.Domain(ids_array_rows_dim)
            ids_attr = tiledb.Attr(
                name="values",
                dtype=np.dtype(np.uint64),
                filters=storage_format["DEFAULT_ATTR_FILTERS"],
            )
            ids_schema = tiledb.ArraySchema(
                domain=ids_array_dom,
                sparse=False,
                attrs=[ids_attr],
                cell_order="col-major",
                tile_order="col-major",
            )
            tiledb.Array.create(ids_uri, ids_schema)
            group.add(ids_uri, name=ids_array_name)

            # Parts array: 2-D dense array; "rows" spans the vector
            # dimensions and "cols" spans up to MAX_INT32 + 1 positions.
            parts_array_rows_dim = tiledb.Dim(
                name="rows",
                domain=(0, dimensions - 1),
                tile=dimensions,
                dtype=np.dtype(np.int32),
            )
            parts_array_cols_dim = tiledb.Dim(
                name="cols",
                domain=(0, MAX_INT32),
                tile=tile_size,
                dtype=np.dtype(np.int32),
            )
            parts_array_dom = tiledb.Domain(
                parts_array_rows_dim, parts_array_cols_dim
            )
            parts_attr = tiledb.Attr(
                name="values",
                dtype=vector_type,
                filters=storage_format["DEFAULT_ATTR_FILTERS"],
            )
            parts_schema = tiledb.ArraySchema(
                domain=parts_array_dom,
                sparse=False,
                attrs=[parts_attr],
                cell_order="col-major",
                tile_order="col-major",
            )
            tiledb.Array.create(parts_uri, parts_schema)
            group.add(parts_uri, name=parts_array_name)
        finally:
            group.close()
        return FlatIndex(uri=uri, config=config)
0 commit comments