33
44from tiledb .cloud .dag import Mode
55from tiledb .vector_search .index import FlatIndex , IVFFlatIndex , Index
6+ import numpy as np
67
78
89def ingest (
910 index_type : str ,
1011 index_uri : str ,
11- source_uri : str ,
12- source_type : str ,
1312 * ,
13+ input_vectors : np .ndarray = None ,
14+ source_uri : str = None ,
15+ source_type : str = None ,
1416 config = None ,
1517 namespace : Optional [str ] = None ,
1618 size : int = - 1 ,
@@ -32,10 +34,12 @@ def ingest(
3234 Type of vector index (FLAT, IVF_FLAT)
3335 index_uri: str
3436 Vector index URI (stored as TileDB group)
37+ input_vectors: numpy Array
38+ Input vectors, if this is provided it takes precedence over source_uri and source_type.
3539 source_uri: str
3640 Data source URI
3741 source_type: str
38- Type of the source data
42+ Type of the source data. If left empty it is auto-detected from the suffix of source_uri
3943 config: None
4044 config dictionary, defaults to None
4145 namespace: str
@@ -88,6 +92,9 @@ def ingest(
8892 INDEX_ARRAY_NAME = storage_formats [STORAGE_VERSION ]["INDEX_ARRAY_NAME" ]
8993 IDS_ARRAY_NAME = storage_formats [STORAGE_VERSION ]["IDS_ARRAY_NAME" ]
9094 PARTS_ARRAY_NAME = storage_formats [STORAGE_VERSION ]["PARTS_ARRAY_NAME" ]
95+ INPUT_VECTORS_ARRAY_NAME = storage_formats [STORAGE_VERSION ][
96+ "INPUT_VECTORS_ARRAY_NAME"
97+ ]
9198 PARTIAL_WRITE_ARRAY_DIR = storage_formats [STORAGE_VERSION ][
9299 "PARTIAL_WRITE_ARRAY_DIR"
93100 ]
@@ -139,8 +146,22 @@ def setup(
139146
140147 return logger
141148
149+ def autodetect_source_type (source_uri : str ) -> str :
150+ if source_uri .endswith (".u8bin" ):
151+ return "U8BIN"
152+ elif source_uri .endswith (".f32bin" ):
153+ return "F32BIN"
154+ elif source_uri .endswith (".fvecs" ):
155+ return "FVEC"
156+ elif source_uri .endswith (".ivecs" ):
157+ return "IVEC"
158+ elif source_uri .endswith (".bvecs" ):
159+ return "BVEC"
160+ else :
161+ return "TILEDB_ARRAY"
162+
142163 def read_source_metadata (
143- source_uri : str , source_type : str , logger : logging . Logger
164+ source_uri : str , source_type : str = None
144165 ) -> Tuple [int , int , np .dtype ]:
145166 if source_type == "TILEDB_ARRAY" :
146167 schema = tiledb .ArraySchema .load (source_uri )
@@ -189,6 +210,53 @@ def read_source_metadata(
189210 else :
190211 raise ValueError (f"Not supported source_type { source_type } " )
191212
213+ def write_input_vectors (
214+ group : tiledb .Group ,
215+ input_vectors : np .ndarray ,
216+ size : int ,
217+ dimensions : int ,
218+ vector_type : np .dtype ,
219+ ) -> str :
220+ input_vectors_array_uri = f"{ group .uri } /{ INPUT_VECTORS_ARRAY_NAME } "
221+ if tiledb .array_exists (input_vectors_array_uri ):
222+ raise ValueError (f"Array exists { input_vectors_array_uri } " )
223+
224+ logger .debug ("Creating input vectors array" )
225+ input_vectors_array_rows_dim = tiledb .Dim (
226+ name = "rows" ,
227+ domain = (0 , dimensions - 1 ),
228+ tile = dimensions ,
229+ dtype = np .dtype (np .int32 ),
230+ )
231+ input_vectors_array_cols_dim = tiledb .Dim (
232+ name = "cols" ,
233+ domain = (0 , size - 1 ),
234+ tile = int (size / partitions ),
235+ dtype = np .dtype (np .int32 ),
236+ )
237+ input_vectors_array_dom = tiledb .Domain (
238+ input_vectors_array_rows_dim , input_vectors_array_cols_dim
239+ )
240+ input_vectors_array_attr = tiledb .Attr (
241+ name = "values" , dtype = vector_type , filters = DEFAULT_ATTR_FILTERS
242+ )
243+ input_vectors_array_schema = tiledb .ArraySchema (
244+ domain = input_vectors_array_dom ,
245+ sparse = False ,
246+ attrs = [input_vectors_array_attr ],
247+ cell_order = "col-major" ,
248+ tile_order = "col-major" ,
249+ )
250+ logger .debug (input_vectors_array_schema )
251+ tiledb .Array .create (input_vectors_array_uri , input_vectors_array_schema )
252+ group .add (input_vectors_array_uri , name = INPUT_VECTORS_ARRAY_NAME )
253+
254+ input_vectors_array = tiledb .open (input_vectors_array_uri , "w" )
255+ input_vectors_array [:, :] = np .transpose (input_vectors )
256+ input_vectors_array .close ()
257+
258+ return input_vectors_array_uri
259+
192260 def create_arrays (
193261 group : tiledb .Group ,
194262 index_type : str ,
@@ -501,7 +569,7 @@ def read_input_vectors(
501569 config : Optional [Mapping [str , Any ]] = None ,
502570 verbose : bool = False ,
503571 trace_id : Optional [str ] = None ,
504- ) -> np .array :
572+ ) -> np .ndarray :
505573 logger = setup (config , verbose )
506574 logger .debug (
507575 "Reading input vectors start_pos: %i, end_pos: %i" , start_pos , end_pos
@@ -669,7 +737,7 @@ def init_centroids(
669737 config : Optional [Mapping [str , Any ]] = None ,
670738 verbose : bool = False ,
671739 trace_id : Optional [str ] = None ,
672- ) -> np .array :
740+ ) -> np .ndarray :
673741 logger = setup (config , verbose )
674742 logger .debug (
675743 "Initialising centroids by reading the first vectors in the source data."
@@ -688,7 +756,7 @@ def init_centroids(
688756 )
689757
690758 def assign_points_and_partial_new_centroids (
691- centroids : np .array ,
759+ centroids : np .ndarray ,
692760 source_uri : str ,
693761 source_type : str ,
694762 vector_type : np .dtype ,
@@ -859,7 +927,7 @@ def ingest_flat(
859927 target .close ()
860928
861929 def write_centroids (
862- centroids : np .array ,
930+ centroids : np .ndarray ,
863931 index_group_uri : str ,
864932 partitions : int ,
865933 dimensions : int ,
@@ -1379,12 +1447,14 @@ def consolidate_and_vacuum(
13791447 index_group_uri : str ,
13801448 config : Optional [Mapping [str , Any ]] = None ,
13811449 ):
1450+ group = tiledb .Group (index_group_uri , config = config )
1451+ if INPUT_VECTORS_ARRAY_NAME in group :
1452+ tiledb .Array .delete_array (group [INPUT_VECTORS_ARRAY_NAME ].uri )
13821453 modes = ["fragment_meta" , "commits" , "array_meta" ]
13831454 for mode in modes :
13841455 conf = tiledb .Config (config )
13851456 conf ["sm.consolidation.mode" ] = mode
13861457 conf ["sm.vacuum.mode" ] = mode
1387- group = tiledb .Group (index_group_uri , config = conf )
13881458 tiledb .consolidate (group [PARTS_ARRAY_NAME ].uri , config = conf )
13891459 tiledb .vacuum (group [PARTS_ARRAY_NAME ].uri , config = conf )
13901460 if index_type == "IVF_FLAT" :
@@ -1416,9 +1486,24 @@ def consolidate_and_vacuum(
14161486 raise err
14171487 group = tiledb .Group (index_group_uri , "w" )
14181488
1419- in_size , dimensions , vector_type = read_source_metadata (
1420- source_uri = source_uri , source_type = source_type , logger = logger
1421- )
1489+ if input_vectors is not None :
1490+ in_size = input_vectors .shape [0 ]
1491+ dimensions = input_vectors .shape [1 ]
1492+ vector_type = input_vectors .dtype
1493+ source_uri = write_input_vectors (
1494+ group = group ,
1495+ input_vectors = input_vectors ,
1496+ size = in_size ,
1497+ dimensions = dimensions ,
1498+ vector_type = vector_type ,
1499+ )
1500+ source_type = "TILEDB_ARRAY"
1501+ else :
1502+ if source_type is None :
1503+ source_type = autodetect_source_type (source_uri = source_uri )
1504+ in_size , dimensions , vector_type = read_source_metadata (
1505+ source_uri = source_uri , source_type = source_type
1506+ )
14221507 if size == - 1 :
14231508 size = in_size
14241509 if size > in_size :
0 commit comments