|
21 | 21 | import dataclasses |
22 | 22 | import datetime as dt |
23 | 23 | import itertools |
24 | | -from typing import Any |
| 24 | +from typing import Any, Literal |
25 | 25 |
|
26 | 26 | from absl import logging |
27 | 27 | from ml_collections import config_dict |
@@ -351,13 +351,22 @@ def insert_window( |
351 | 351 | recording_id: int, |
352 | 352 | offsets: list[float], |
353 | 353 | embedding: np.ndarray | None = None, |
| 354 | + handle_duplicates: Literal[ |
| 355 | + 'allow', 'overwrite', 'skip', 'error' |
| 356 | + ] = 'error', |
354 | 357 | **kwargs: Any, |
355 | 358 | ) -> int: |
356 | 359 | """Insert a window into the database.""" |
357 | 360 |
|
358 | 361 | if recording_id not in self._recordings: |
359 | 362 | raise ValueError(f'Recording id not found: {recording_id}') |
360 | 363 |
|
| 364 | + duplicate_id = self._handle_window_duplicates( |
| 365 | + recording_id, offsets, handle_duplicates |
| 366 | + ) |
| 367 | + if duplicate_id is not None: |
| 368 | + return duplicate_id |
| 369 | + |
361 | 370 | window_id = self._next_window_id |
362 | 371 | self._windows[window_id] = interface.Window( |
363 | 372 | id=window_id, |
@@ -403,29 +412,21 @@ def insert_annotation( |
403 | 412 | label: str, |
404 | 413 | label_type: interface.LabelType, |
405 | 414 | provenance: str, |
406 | | - skip_duplicates: bool = False, |
| 415 | + handle_duplicates: Literal[ |
| 416 | + 'allow', 'overwrite', 'skip', 'error' |
| 417 | + ] = 'error', |
407 | 418 | **kwargs: Any, |
408 | 419 | ) -> int: |
409 | 420 | """Insert an annotation into the database.""" |
410 | 421 |
|
411 | 422 | if recording_id not in self._recordings: |
412 | 423 | raise ValueError(f'Recording id not found: {recording_id}') |
413 | 424 |
|
414 | | - if skip_duplicates: |
415 | | - matches = self.get_all_annotations( |
416 | | - config_dict.create( |
417 | | - eq=dict( |
418 | | - recording_id=recording_id, |
419 | | - label=label, |
420 | | - label_type=label_type, |
421 | | - ), |
422 | | - approx=dict( |
423 | | - offsets=offsets, |
424 | | - ), |
425 | | - ) |
426 | | - ) |
427 | | - if matches: |
428 | | - return matches[0].id |
| 425 | + duplicate_id = self._handle_annotation_duplicates( |
| 426 | + recording_id, offsets, label, label_type, provenance, handle_duplicates |
| 427 | + ) |
| 428 | + if duplicate_id is not None: |
| 429 | + return duplicate_id |
429 | 430 |
|
430 | 431 | annotation_id = self._next_annotation_id |
431 | 432 | self._annotations[annotation_id] = interface.Annotation( |
|
0 commit comments