Skip to content

Commit d367dda

Browse files
stefanistratecopybara-github
authored andcommitted
Rework how duplicate windows & annotations are inserted.
Let the user control what they want to do with duplicates: allow / overwrite / skip / throw an error. Fixes #80 PiperOrigin-RevId: 865360367
1 parent 6d9f817 commit d367dda

File tree

12 files changed

+321
-59
lines changed

12 files changed

+321
-59
lines changed

perch_hoplite/agile/02_agile_modeling.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@
170170
" label=ann.label,\n",
171171
" label_type=ann.label_type,\n",
172172
" provenance=ann.provenance,\n",
173-
" skip_duplicates=True,\n",
173+
" handle_duplicates=\"skip\",\n",
174174
" )\n",
175175
"\n",
176176
"print(\"Annotations after saving new labels:\", len(db.get_all_annotations()))"
@@ -310,7 +310,7 @@
310310
" label=ann.label,\n",
311311
" label_type=ann.label_type,\n",
312312
" provenance=ann.provenance,\n",
313-
" skip_duplicates=True,\n",
313+
" handle_duplicates=\"skip\",\n",
314314
" )\n",
315315
"\n",
316316
"print(\"Annotations after saving new labels:\", len(db.get_all_annotations()))"

perch_hoplite/agile/99_migrate_db.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@
244244
" label=old_label,\n",
245245
" label_type=interface.LabelType(old_type),\n",
246246
" provenance=old_provenance,\n",
247+
" handle_duplicates=\"allow\",\n",
247248
" )\n",
248249
"new_db.commit()"
249250
],

perch_hoplite/agile/ingest_annotations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def ingest_dataset(
9494
label=label,
9595
label_type=interface.LabelType.POSITIVE,
9696
provenance=provenance,
97+
handle_duplicates='allow',
9798
)
9899
lbl_counts[label] += 1
99100
lbl_count = sum(lbl_counts.values())

perch_hoplite/agile/tests/classifier_data_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def add_label(window_id, lbl_idx, lbl_type):
211211
label=db_test_utils.CLASS_LABELS[lbl_idx],
212212
label_type=lbl_type,
213213
provenance='test',
214+
handle_duplicates='allow',
214215
)
215216

216217
with self.subTest('single_positive_label'):

perch_hoplite/db/README.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,20 @@ window_id = db.insert_window(
303303
embedding=np.random.normal(size=1280),
304304
extra=999,
305305
)
306+
window_id = db.insert_window(
307+
recording_id=1,
308+
offsets=[2, 3],
309+
embedding=np.random.normal(size=1280),
310+
extra=1000,
311+
handle_duplicates="overwrite",
312+
)
313+
window_id = db.insert_window(
314+
recording_id=1,
315+
offsets=[2, 3],
316+
embedding=np.random.normal(size=1280),
317+
extra=1001,
318+
handle_duplicates="skip",
319+
)
306320
```
307321

308322
Retrieving an existing window can be done like this:
@@ -373,7 +387,7 @@ annotation_id = db.insert_annotation(
373387
label="wolf",
374388
label_type=LabelType.POSITIVE,
375389
provenance="me",
376-
skip_duplicates=True,
390+
handle_duplicates="skip",
377391
)
378392
```
379393

perch_hoplite/db/db_loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def duplicate_db(
9494
for annotation in tqdm.tqdm(source_db.get_all_annotations()):
9595
target_id = target_db.insert_annotation(
9696
window_id=window_id_mapping[annotation.window_id],
97+
handle_duplicates='allow',
9798
**annotation.to_kwargs(skip=['id', 'window_id']),
9899
)
99100
annotation_id_mapping[annotation.id] = target_id

perch_hoplite/db/in_mem_impl.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import dataclasses
2222
import datetime as dt
2323
import itertools
24-
from typing import Any
24+
from typing import Any, Literal
2525

2626
from absl import logging
2727
from ml_collections import config_dict
@@ -351,13 +351,22 @@ def insert_window(
351351
recording_id: int,
352352
offsets: list[float],
353353
embedding: np.ndarray | None = None,
354+
handle_duplicates: Literal[
355+
'allow', 'overwrite', 'skip', 'error'
356+
] = 'error',
354357
**kwargs: Any,
355358
) -> int:
356359
"""Insert a window into the database."""
357360

358361
if recording_id not in self._recordings:
359362
raise ValueError(f'Recording id not found: {recording_id}')
360363

364+
duplicate_id = self._handle_window_duplicates(
365+
recording_id, offsets, handle_duplicates
366+
)
367+
if duplicate_id is not None:
368+
return duplicate_id
369+
361370
window_id = self._next_window_id
362371
self._windows[window_id] = interface.Window(
363372
id=window_id,
@@ -403,29 +412,21 @@ def insert_annotation(
403412
label: str,
404413
label_type: interface.LabelType,
405414
provenance: str,
406-
skip_duplicates: bool = False,
415+
handle_duplicates: Literal[
416+
'allow', 'overwrite', 'skip', 'error'
417+
] = 'error',
407418
**kwargs: Any,
408419
) -> int:
409420
"""Insert an annotation into the database."""
410421

411422
if recording_id not in self._recordings:
412423
raise ValueError(f'Recording id not found: {recording_id}')
413424

414-
if skip_duplicates:
415-
matches = self.get_all_annotations(
416-
config_dict.create(
417-
eq=dict(
418-
recording_id=recording_id,
419-
label=label,
420-
label_type=label_type,
421-
),
422-
approx=dict(
423-
offsets=offsets,
424-
),
425-
)
426-
)
427-
if matches:
428-
return matches[0].id
425+
duplicate_id = self._handle_annotation_duplicates(
426+
recording_id, offsets, label, label_type, provenance, handle_duplicates
427+
)
428+
if duplicate_id is not None:
429+
return duplicate_id
429430

430431
annotation_id = self._next_annotation_id
431432
self._annotations[annotation_id] = interface.Annotation(

0 commit comments

Comments
 (0)