Skip to content

Commit fe8401c

Browse files
sdenton4copybara-github
authored andcommitted
Use the insert_windows_batch in the EmbedWorker to speed up insertion.
PiperOrigin-RevId: 871916791
1 parent e6b6123 commit fe8401c

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

perch_hoplite/agile/embed.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -79,20 +79,6 @@ def process_source_id(
7979
):
8080
return
8181

82-
embs = db.match_window_ids(
83-
deployments_filter=config_dict.create(
84-
eq=dict(project=source_id.dataset_name)
85-
),
86-
recordings_filter=config_dict.create(eq=dict(filename=source_id.file_id)),
87-
windows_filter=config_dict.create(
88-
approx=dict(
89-
offsets=[source_id.offset_s, source_id.offset_s + window_size_s],
90-
)
91-
),
92-
)
93-
if embs:
94-
return
95-
9682
outputs = worker.embedding_model.embed(audio_array)
9783
logits_key = state['worker'].model_config.logits_key
9884
if logits_key is None:
@@ -307,11 +293,19 @@ def embedding_exists(
307293
)
308294
return bool(embs)
309295

310-
def process_all(self, target_dataset_name: str | None = None, batch_size=32):
296+
def process_all(
297+
self,
298+
target_dataset_name: str | None = None,
299+
batch_size=32,
300+
handle_duplicates='error',
301+
):
311302
"""Process all audio examples."""
312303

313304
# Update model config and audio sources in the database.
314305
self.update_configs()
306+
if self.db.count_embeddings() == 0:
307+
# No chance of duplicates, so we can use "allow" mode.
308+
handle_duplicates = 'allow'
315309

316310
# Create missing deployments and recordings in the database.
317311
source_id_to_deployment_id = {}
@@ -372,14 +366,19 @@ def process_all(self, target_dataset_name: str | None = None, batch_size=32):
372366
for result in got:
373367
if result is None:
374368
continue
375-
for source, offsets, embedding in zip(*result):
376-
source_id = source.to_id()
377-
recording_id = source_id_to_recording_id[source_id]
378-
self.db.insert_window(
379-
recording_id=recording_id,
380-
embedding=embedding,
381-
offsets=offsets,
382-
)
369+
windows_batch = [
370+
{
371+
'recording_id': source_id_to_recording_id[s.to_id()],
372+
'offsets': o,
373+
}
374+
for s, o, _ in zip(*result)
375+
]
376+
embeddings_batch = np.array(result[2])
377+
self.db.insert_windows_batch(
378+
windows_batch,
379+
embeddings_batch,
380+
handle_duplicates=handle_duplicates,
381+
)
383382

384383
# Commit all changes for windows to the database.
385384
self.db.commit()

0 commit comments

Comments
 (0)