66
77def ingest_embeddings_with_driver (
88 object_index_uri : str ,
9- embeddings_uri : str ,
9+ use_updates_array : bool ,
1010 metadata_array_uri : str = None ,
1111 index_timestamp : int = None ,
1212 workers : int = - 1 ,
@@ -30,7 +30,7 @@ def ingest_embeddings_with_driver(
3030):
3131 def ingest_embeddings (
3232 object_index_uri : str ,
33- embeddings_uri : str ,
33+ use_updates_array : bool ,
3434 metadata_array_uri : str = None ,
3535 index_timestamp : int = None ,
3636 workers : int = - 1 ,
@@ -81,6 +81,7 @@ def install_extra_worker_modules():
8181 from tiledb .vector_search import ingest
8282 from tiledb .vector_search .object_api import ObjectIndex
8383 from tiledb .vector_search .object_readers import ObjectPartition
84+ from tiledb .vector_search .storage_formats import storage_formats
8485
8586 MAX_TASKS_PER_STAGE = 100
8687 DEFAULT_IMG_NAME = "3.9-vectorsearch"
@@ -121,14 +122,9 @@ def setup(
121122 # UDFs
122123 # --------------------------------------------------------------------
123124 def compute_embeddings_udf (
124- object_reader_source_code : str ,
125- object_reader_class_name : str ,
126- object_reader_kwargs : Dict ,
127- object_embedding_source_code : str ,
128- object_embedding_class_name : str ,
129- object_embedding_kwargs : Dict ,
125+ object_index_uri : str ,
130126 partition_dicts : List [Dict ],
131- embeddings_uri : str ,
127+ use_updates_array : bool ,
132128 metadata_array_uri : str = None ,
133129 index_timestamp : int = None ,
134130 verbose : bool = False ,
@@ -182,16 +178,15 @@ def instantiate_object(code, class_name, **kwargs):
182178 return class_ (** kwargs )
183179
184180 logger = setup (config , verbose )
185- object_reader = instantiate_object (
186- code = object_reader_source_code ,
187- class_name = object_reader_class_name ,
188- ** object_reader_kwargs ,
189- )
190- object_embedding = instantiate_object (
191- code = object_embedding_source_code ,
192- class_name = object_embedding_class_name ,
193- ** object_embedding_kwargs ,
181+ obj_index = object_index .ObjectIndex (
182+ object_index_uri ,
183+ config = config ,
184+ environment_variables = environment_variables ,
185+ load_embedding = False ,
186+ load_metadata_in_memory = False ,
194187 )
188+ object_reader = obj_index .object_reader
189+ object_embedding = obj_index .embedding
195190 for var , val in environment_variables .items ():
196191 os .environ [var ] = val
197192 with tiledb .scope_ctx (ctx_or_config = config ):
@@ -201,17 +196,23 @@ def instantiate_object(code, class_name, **kwargs):
201196 object_embedding .dimensions ()
202197 vector_type = object_embedding .vector_type ()
203198
204- logger .debug ("embeddings_uri %s" , embeddings_uri )
205- embeddings_array = tiledb .open (
206- embeddings_uri , "w" , timestamp = index_timestamp
207- )
199+ if not use_updates_array :
200+ embeddings_array_name = storage_formats [
201+ obj_index .index .storage_version
202+ ]["INPUT_VECTORS_ARRAY_NAME" ]
203+ embeddings_array_uri = f"{ obj_index .uri } /{ embeddings_array_name } "
204+ logger .debug ("embeddings_uri %s" , embeddings_array_uri )
205+ embeddings_array = tiledb .open (
206+ embeddings_array_uri , "w" , timestamp = index_timestamp
207+ )
208+
208209 if metadata_array_uri is not None :
209210 metadata_array = tiledb .open (
210211 metadata_array_uri , "w" , timestamp = index_timestamp
211212 )
212213 for partition_dict in partition_dicts :
213214 partition = instantiate_object (
214- code = object_reader_source_code ,
215+ code = obj_index . object_reader_source_code ,
215216 class_name = object_reader .partition_class_name (),
216217 ** partition_dict ,
217218 )
@@ -224,22 +225,36 @@ def instantiate_object(code, class_name, **kwargs):
224225 embeddings = object_embedding .embed (objects , metadata )
225226
226227 logger .debug ("Write embeddings partition_id: %d" , partition_id )
227- embeddings_flattened = np .empty (1 , dtype = "O" )
228- embeddings_flattened [0 ] = embeddings .astype (vector_type ).flatten ()
229- embeddings_shape = np .empty (1 , dtype = "O" )
230- embeddings_shape [0 ] = np .array (embeddings .shape , dtype = np .uint32 )
231- external_ids = np .empty (1 , dtype = "O" )
232- external_ids [0 ] = objects ["external_id" ].astype (np .uint64 )
233- embeddings_array [partition_id ] = {
234- "vectors" : embeddings_flattened ,
235- "vectors_shape" : embeddings_shape ,
236- "external_ids" : external_ids ,
237- }
228+ if use_updates_array :
229+ vectors = np .empty (embeddings .shape [0 ], dtype = "O" )
230+ for i in range (embeddings .shape [0 ]):
231+ vectors [i ] = embeddings [i ].astype (vector_type )
232+ obj_index .index .update_batch (
233+ vectors = vectors ,
234+ external_ids = objects ["external_id" ].astype (np .uint64 ),
235+ )
236+ else :
237+ embeddings_flattened = np .empty (1 , dtype = "O" )
238+ embeddings_flattened [0 ] = embeddings .astype (
239+ vector_type
240+ ).flatten ()
241+ embeddings_shape = np .empty (1 , dtype = "O" )
242+ embeddings_shape [0 ] = np .array (
243+ embeddings .shape , dtype = np .uint32
244+ )
245+ external_ids = np .empty (1 , dtype = "O" )
246+ external_ids [0 ] = objects ["external_id" ].astype (np .uint64 )
247+ embeddings_array [partition_id ] = {
248+ "vectors" : embeddings_flattened ,
249+ "vectors_shape" : embeddings_shape ,
250+ "external_ids" : external_ids ,
251+ }
238252 if metadata_array_uri is not None :
239253 external_ids = metadata .pop ("external_id" , None )
240254 metadata_array [external_ids ] = metadata
241255
242- embeddings_array .close ()
256+ if not use_updates_array :
257+ embeddings_array .close ()
243258 if metadata_array_uri is not None :
244259 metadata_array .close ()
245260
@@ -254,8 +269,8 @@ def submit_local(d, func, *args, **kwargs):
254269 return d .submit_local (func , * args , ** kwargs )
255270
256271 def create_dag (
257- ob_index : ObjectIndex ,
258- embeddings_uri : str ,
272+ obj_index : ObjectIndex ,
273+ use_updates_array : bool ,
259274 partitions : List [ObjectPartition ],
260275 object_partitions_per_worker : int ,
261276 object_work_tasks : int ,
@@ -322,14 +337,9 @@ def create_dag(
322337 ] = worker_access_credentials_name
323338 submit (
324339 compute_embeddings_udf ,
325- object_reader_source_code = ob_index .object_reader_source_code ,
326- object_reader_class_name = ob_index .object_reader_class_name ,
327- object_reader_kwargs = ob_index .object_reader_kwargs ,
328- object_embedding_source_code = ob_index .embedding_source_code ,
329- object_embedding_class_name = ob_index .embedding_class_name ,
330- object_embedding_kwargs = ob_index .embedding_kwargs ,
340+ object_index_uri = obj_index .uri ,
331341 partition_dicts = partition_dicts ,
332- embeddings_uri = embeddings_uri ,
342+ use_updates_array = use_updates_array ,
333343 metadata_array_uri = metadata_array_uri ,
334344 index_timestamp = index_timestamp ,
335345 verbose = verbose ,
@@ -361,12 +371,14 @@ def create_dag(
361371
362372 from tiledb .vector_search .object_api import object_index
363373
364- ob_index = object_index .ObjectIndex (
374+ obj_index = object_index .ObjectIndex (
365375 object_index_uri ,
366376 config = config ,
367377 environment_variables = environment_variables ,
378+ load_embedding = False ,
379+ load_metadata_in_memory = False ,
368380 )
369- partitions = ob_index .object_reader .get_partitions (** kwargs )
381+ partitions = obj_index .object_reader .get_partitions (** kwargs )
370382 object_partitions = len (partitions )
371383 object_partitions_per_worker = 1
372384 if max_tasks_per_stage == - 1 :
@@ -395,8 +407,8 @@ def create_dag(
395407
396408 logger .debug ("Creating ingestion graph" )
397409 d = create_dag (
398- ob_index = ob_index ,
399- embeddings_uri = embeddings_uri ,
410+ obj_index = obj_index ,
411+ use_updates_array = use_updates_array ,
400412 partitions = partitions ,
401413 object_partitions_per_worker = object_partitions_per_worker ,
402414 object_work_tasks = object_work_tasks ,
@@ -419,19 +431,29 @@ def create_dag(
419431 logger .debug ("Submitted ingestion graph" )
420432 d .wait ()
421433
422- ob_index .index = ingest (
423- index_type = ob_index .index_type ,
424- index_uri = ob_index .uri ,
425- source_uri = embeddings_uri ,
426- source_type = "TILEDB_PARTITIONED_ARRAY" ,
427- external_ids_uri = embeddings_uri ,
428- external_ids_type = "TILEDB_PARTITIONED_ARRAY" ,
429- index_timestamp = index_timestamp ,
430- storage_version = ob_index .index .storage_version ,
431- config = config ,
432- mode = vector_indexing_mode ,
433- ** kwargs ,
434- )
434+ if use_updates_array :
435+ obj_index .index .consolidate_updates (
436+ mode = vector_indexing_mode ,
437+ ** kwargs ,
438+ )
439+ else :
440+ embeddings_array_name = storage_formats [
441+ obj_index .index .storage_version
442+ ]["INPUT_VECTORS_ARRAY_NAME" ]
443+ embeddings_array_uri = f"{ obj_index .uri } /{ embeddings_array_name } "
444+ obj_index .index = ingest (
445+ index_type = obj_index .index_type ,
446+ index_uri = obj_index .uri ,
447+ source_uri = embeddings_array_uri ,
448+ source_type = "TILEDB_PARTITIONED_ARRAY" ,
449+ external_ids_uri = embeddings_array_uri ,
450+ external_ids_type = "TILEDB_PARTITIONED_ARRAY" ,
451+ index_timestamp = index_timestamp ,
452+ storage_version = obj_index .index .storage_version ,
453+ config = config ,
454+ mode = vector_indexing_mode ,
455+ ** kwargs ,
456+ )
435457
436458 def submit_local (d , func , * args , ** kwargs ):
437459 # Drop kwarg
@@ -468,7 +490,7 @@ def submit_local(d, func, *args, **kwargs):
468490 submit (
469491 ingest_embeddings ,
470492 object_index_uri = object_index_uri ,
471- embeddings_uri = embeddings_uri ,
493+ use_updates_array = use_updates_array ,
472494 metadata_array_uri = metadata_array_uri ,
473495 index_timestamp = index_timestamp ,
474496 max_tasks_per_stage = max_tasks_per_stage ,
0 commit comments