@@ -79,20 +79,6 @@ def process_source_id(
7979 ):
8080 return
8181
82- embs = db .match_window_ids (
83- deployments_filter = config_dict .create (
84- eq = dict (project = source_id .dataset_name )
85- ),
86- recordings_filter = config_dict .create (eq = dict (filename = source_id .file_id )),
87- windows_filter = config_dict .create (
88- approx = dict (
89- offsets = [source_id .offset_s , source_id .offset_s + window_size_s ],
90- )
91- ),
92- )
93- if embs :
94- return
95-
9682 outputs = worker .embedding_model .embed (audio_array )
9783 logits_key = state ['worker' ].model_config .logits_key
9884 if logits_key is None :
@@ -307,11 +293,19 @@ def embedding_exists(
307293 )
308294 return bool (embs )
309295
310- def process_all (self , target_dataset_name : str | None = None , batch_size = 32 ):
296+ def process_all (
297+ self ,
298+ target_dataset_name : str | None = None ,
299+ batch_size = 32 ,
300+ handle_duplicates = 'error' ,
301+ ):
311302 """Process all audio examples."""
312303
313304 # Update model config and audio sources in the database.
314305 self .update_configs ()
306+ if self .db .count_embeddings () == 0 :
307+ # No chance of duplicates, so we can use "allow" mode.
308+ handle_duplicates = 'allow'
315309
316310 # Create missing deployments and recordings in the database.
317311 source_id_to_deployment_id = {}
@@ -372,14 +366,19 @@ def process_all(self, target_dataset_name: str | None = None, batch_size=32):
372366 for result in got :
373367 if result is None :
374368 continue
375- for source , offsets , embedding in zip (* result ):
376- source_id = source .to_id ()
377- recording_id = source_id_to_recording_id [source_id ]
378- self .db .insert_window (
379- recording_id = recording_id ,
380- embedding = embedding ,
381- offsets = offsets ,
382- )
369+ windows_batch = [
370+ {
371+ 'recording_id' : source_id_to_recording_id [s .to_id ()],
372+ 'offsets' : o ,
373+ }
374+ for s , o , _ in zip (* result )
375+ ]
376+ embeddings_batch = np .array (result [2 ])
377+ self .db .insert_windows_batch (
378+ windows_batch ,
379+ embeddings_batch ,
380+ handle_duplicates = handle_duplicates ,
381+ )
383382
384383 # Commit all changes for windows to the database.
385384 self .db .commit ()
0 commit comments