Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions perch_hoplite/agile/02_agile_modeling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
" label=ann.label,\n",
" label_type=ann.label_type,\n",
" provenance=ann.provenance,\n",
" skip_duplicates=True,\n",
" handle_duplicates=\"skip\",\n",
" )\n",
"\n",
"print(\"Annotations after saving new labels:\", len(db.get_all_annotations()))"
Expand Down Expand Up @@ -310,7 +310,7 @@
" label=ann.label,\n",
" label_type=ann.label_type,\n",
" provenance=ann.provenance,\n",
" skip_duplicates=True,\n",
" handle_duplicates=\"skip\",\n",
" )\n",
"\n",
"print(\"Annotations after saving new labels:\", len(db.get_all_annotations()))"
Expand Down
1 change: 1 addition & 0 deletions perch_hoplite/agile/99_migrate_db.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@
" label=old_label,\n",
" label_type=interface.LabelType(old_type),\n",
" provenance=old_provenance,\n",
" handle_duplicates=\"allow\",\n",
" )\n",
"new_db.commit()"
],
Expand Down
1 change: 1 addition & 0 deletions perch_hoplite/agile/ingest_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def ingest_dataset(
label=label,
label_type=interface.LabelType.POSITIVE,
provenance=provenance,
handle_duplicates='allow',
)
lbl_counts[label] += 1
lbl_count = sum(lbl_counts.values())
Expand Down
1 change: 1 addition & 0 deletions perch_hoplite/agile/tests/classifier_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def add_label(window_id, lbl_idx, lbl_type):
label=db_test_utils.CLASS_LABELS[lbl_idx],
label_type=lbl_type,
provenance='test',
handle_duplicates='allow',
)

with self.subTest('single_positive_label'):
Expand Down
16 changes: 15 additions & 1 deletion perch_hoplite/db/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,20 @@ window_id = db.insert_window(
embedding=np.random.normal(size=1280),
extra=999,
)
window_id = db.insert_window(
recording_id=1,
offsets=[2, 3],
embedding=np.random.normal(size=1280),
extra=1000,
handle_duplicates="overwrite",
)
window_id = db.insert_window(
recording_id=1,
offsets=[2, 3],
embedding=np.random.normal(size=1280),
extra=1001,
handle_duplicates="skip",
)
```

Retrieving an existing window can be done like this:
Expand Down Expand Up @@ -373,7 +387,7 @@ annotation_id = db.insert_annotation(
label="wolf",
label_type=LabelType.POSITIVE,
provenance="me",
skip_duplicates=True,
handle_duplicates="skip",
)
```

Expand Down
1 change: 1 addition & 0 deletions perch_hoplite/db/db_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def duplicate_db(
for annotation in tqdm.tqdm(source_db.get_all_annotations()):
target_id = target_db.insert_annotation(
window_id=window_id_mapping[annotation.window_id],
handle_duplicates='allow',
**annotation.to_kwargs(skip=['id', 'window_id']),
)
annotation_id_mapping[annotation.id] = target_id
Expand Down
35 changes: 18 additions & 17 deletions perch_hoplite/db/in_mem_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import dataclasses
import datetime as dt
import itertools
from typing import Any
from typing import Any, Literal

from absl import logging
from ml_collections import config_dict
Expand Down Expand Up @@ -351,13 +351,22 @@ def insert_window(
recording_id: int,
offsets: list[float],
embedding: np.ndarray | None = None,
handle_duplicates: Literal[
'allow', 'overwrite', 'skip', 'error'
] = 'error',
**kwargs: Any,
) -> int:
"""Insert a window into the database."""

if recording_id not in self._recordings:
raise ValueError(f'Recording id not found: {recording_id}')

duplicate_id = self._handle_window_duplicates(
recording_id, offsets, handle_duplicates
)
if duplicate_id is not None:
return duplicate_id

window_id = self._next_window_id
self._windows[window_id] = interface.Window(
id=window_id,
Expand Down Expand Up @@ -403,29 +412,21 @@ def insert_annotation(
label: str,
label_type: interface.LabelType,
provenance: str,
skip_duplicates: bool = False,
handle_duplicates: Literal[
'allow', 'overwrite', 'skip', 'error'
] = 'error',
**kwargs: Any,
) -> int:
"""Insert an annotation into the database."""

if recording_id not in self._recordings:
raise ValueError(f'Recording id not found: {recording_id}')

if skip_duplicates:
matches = self.get_all_annotations(
config_dict.create(
eq=dict(
recording_id=recording_id,
label=label,
label_type=label_type,
),
approx=dict(
offsets=offsets,
),
)
)
if matches:
return matches[0].id
duplicate_id = self._handle_annotation_duplicates(
recording_id, offsets, label, label_type, provenance, handle_duplicates
)
if duplicate_id is not None:
return duplicate_id

annotation_id = self._next_annotation_id
self._annotations[annotation_id] = interface.Annotation(
Expand Down
Loading
Loading