Skip to content

Commit 71d478b

Browse files
authored
fix: modify instrument merge (#1619)
* fix: raise warnings when users include duplicate components, except for HarpDevice with is_clock_generator True * refactor: fix logic and move combine code into its own function * chore: lint
1 parent cc8886e commit 71d478b

File tree

2 files changed

+169
-3
lines changed

2 files changed

+169
-3
lines changed

src/aind_data_schema/core/instrument.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Core Instrument model"""
22

33
from datetime import date
4+
import logging
45
from typing import List, Literal, Optional
56

67
from aind_data_schema_models.modalities import Modality
@@ -267,6 +268,46 @@ def validate_modality_device_dependencies(self):
267268

268269
return self
269270

271+
def _is_harp_clock_generator(self, component) -> bool:
272+
"""Check if a component is a HarpDevice and a clock generator"""
273+
return (
274+
isinstance(component, HarpDevice)
275+
and hasattr(component, "is_clock_generator")
276+
and component.is_clock_generator is True
277+
)
278+
279+
def _combine_components(self, components: List[Device], other_components: List[Device]) -> List[Device]:
280+
"""Combine components from two instruments, handling duplicates appropriately."""
281+
seen_names = set()
282+
combined_components = []
283+
284+
for component in components:
285+
combined_components.append(component)
286+
seen_names.add(component.name)
287+
288+
for other_component in other_components:
289+
if other_component.name in seen_names:
290+
matching_component = next((c for c in components if c.name == other_component.name), None)
291+
if (
292+
matching_component
293+
and self._is_harp_clock_generator(matching_component)
294+
and self._is_harp_clock_generator(other_component)
295+
):
296+
logging.info(
297+
f"{other_component.name} is a HarpDevice clock generator, "
298+
f"only one instance will be kept in the combined instrument."
299+
)
300+
else:
301+
logging.error(
302+
f"Instruments should not have duplicated components,"
303+
f" this will raise an error in future versions: {other_component.name}"
304+
)
305+
else:
306+
combined_components.append(other_component)
307+
seen_names.add(other_component.name)
308+
309+
return combined_components
310+
270311
def __add__(self, other: "Instrument") -> "Instrument":
271312
"""Combine two Instrument objects"""
272313

@@ -306,7 +347,7 @@ def __add__(self, other: "Instrument") -> "Instrument":
306347
combined_connections = self.connections + other.connections
307348

308349
# Combine components
309-
combined_components = self.components + other.components
350+
combined_components = self._combine_components(self.components, other.components)
310351

311352
# Combine notes
312353
combined_notes = merge_notes(self.notes, other.notes)

tests/test_instrument.py

Lines changed: 127 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import json
44
import unittest
55
from datetime import date
6+
from unittest.mock import patch
67

78
from aind_data_schema_models.coordinates import AnatomicalRelative
89
from aind_data_schema_models.modalities import Modality
910
from aind_data_schema_models.organizations import Organization
1011
from aind_data_schema_models.units import FrequencyUnit, PowerUnit, SizeUnit
12+
from aind_data_schema_models.harp_types import HarpDeviceType
1113
from pydantic import ValidationError
1214

1315
from aind_data_schema.components.coordinates import CoordinateSystemLibrary
@@ -25,6 +27,7 @@
2527
EphysAssembly,
2628
EphysProbe,
2729
FiberPatchCord,
30+
HarpDevice,
2831
Laser,
2932
LaserAssembly,
3033
Lens,
@@ -655,8 +658,8 @@ def test_instrument_addition(self):
655658
# Check that modalities are combined and sorted (should be the same since we're adding identical instruments)
656659
self.assertEqual(len(combined.modalities), len(set(inst1.modalities + inst2.modalities)))
657660

658-
# Check that components are combined
659-
self.assertEqual(len(combined.components), len(inst1.components) + len(inst2.components))
661+
# Check that components are deduplicated (same names from both instruments result in keeping only one)
662+
self.assertEqual(len(combined.components), len(inst1.components))
660663

661664
# Check that connections are combined
662665
self.assertEqual(len(combined.connections), len(inst1.connections) + len(inst2.connections))
@@ -710,6 +713,128 @@ def test_instrument_addition(self):
710713
combined = inst1 + inst2
711714
self.assertEqual(combined.notes, "Only note")
712715

716+
def test_duplicate_non_harp_device_components(self):
717+
"""Test that duplicate non-HarpDevice components log an error when combining instruments"""
718+
719+
inst1 = Instrument(
720+
instrument_id="test_inst",
721+
modification_date=date(2020, 10, 10),
722+
modalities=[Modality.ECEPHYS],
723+
coordinate_system=CoordinateSystemLibrary.BREGMA_ARI,
724+
components=[Computer(name="Computer1")],
725+
)
726+
inst2 = Instrument(
727+
instrument_id="test_inst",
728+
modification_date=date(2020, 10, 10),
729+
modalities=[Modality.ECEPHYS],
730+
coordinate_system=CoordinateSystemLibrary.BREGMA_ARI,
731+
components=[Computer(name="Computer1")],
732+
)
733+
734+
with patch("aind_data_schema.core.instrument.logging") as mock_logging:
735+
combined = inst1 + inst2
736+
mock_logging.error.assert_called_once()
737+
error_call_args = mock_logging.error.call_args[0][0]
738+
self.assertIn("Computer1", error_call_args)
739+
self.assertIn("duplicated", error_call_args)
740+
741+
self.assertEqual(len(combined.components), 1)
742+
743+
def test_duplicate_harp_clock_generator_devices(self):
744+
"""Test that duplicate HarpDevice clock generators are allowed when combining instruments"""
745+
746+
harp_clock_gen = HarpDevice(
747+
name="Harp Clock Generator",
748+
harp_device_type=HarpDeviceType.CLOCKSYNCHRONIZER,
749+
core_version="2.1",
750+
channels=[],
751+
is_clock_generator=True,
752+
)
753+
754+
inst1 = Instrument(
755+
instrument_id="test_inst",
756+
modification_date=date(2020, 10, 10),
757+
modalities=[Modality.ECEPHYS],
758+
coordinate_system=CoordinateSystemLibrary.BREGMA_ARI,
759+
components=[harp_clock_gen],
760+
)
761+
inst2 = Instrument(
762+
instrument_id="test_inst",
763+
modification_date=date(2020, 10, 10),
764+
modalities=[Modality.ECEPHYS],
765+
coordinate_system=CoordinateSystemLibrary.BREGMA_ARI,
766+
components=[harp_clock_gen.model_copy(deep=True)],
767+
)
768+
769+
with patch("aind_data_schema.core.instrument.logging") as mock_logging:
770+
combined = inst1 + inst2
771+
mock_logging.info.assert_called_once()
772+
info_call_args = mock_logging.info.call_args[0][0]
773+
self.assertIn("Harp Clock Generator", info_call_args)
774+
775+
self.assertEqual(len(combined.components), 1)
776+
777+
def test_duplicate_non_harp_device_with_clock_generator_attribute(self):
778+
"""Test that duplicate non-HarpDevice components with is_clock_generator log error"""
779+
780+
harp_clock_gen = HarpDevice(
781+
name="CustomClockGenerator",
782+
harp_device_type=HarpDeviceType.BEHAVIOR,
783+
is_clock_generator=True,
784+
channels=[],
785+
)
786+
787+
harp_non_clock_gen = HarpDevice(
788+
name="CustomClockGenerator",
789+
harp_device_type=HarpDeviceType.BEHAVIOR,
790+
is_clock_generator=False,
791+
channels=[],
792+
)
793+
794+
inst1 = Instrument(
795+
instrument_id="test_inst",
796+
modification_date=date(2020, 10, 10),
797+
modalities=[Modality.BEHAVIOR],
798+
coordinate_system=CoordinateSystemLibrary.BREGMA_ARI,
799+
components=[harp_clock_gen, LickSpoutAssembly(
800+
name="Lick spout assembly A",
801+
lick_spouts=[
802+
LickSpout(
803+
name="Left spout",
804+
spout_diameter=1.2,
805+
solenoid_valve=Device(name="Solenoid Left"),
806+
lick_sensor=Device(name="Lick-o-meter Left"),
807+
),
808+
],
809+
)],
810+
)
811+
inst2 = Instrument(
812+
instrument_id="test_inst",
813+
modification_date=date(2020, 10, 10),
814+
modalities=[Modality.BEHAVIOR],
815+
coordinate_system=CoordinateSystemLibrary.BREGMA_ARI,
816+
components=[harp_non_clock_gen, LickSpoutAssembly(
817+
name="Lick spout assembly B",
818+
lick_spouts=[
819+
LickSpout(
820+
name="Left spout",
821+
spout_diameter=1.2,
822+
solenoid_valve=Device(name="Solenoid Left"),
823+
lick_sensor=Device(name="Lick-o-meter Left"),
824+
),
825+
],
826+
)],
827+
)
828+
829+
with patch("aind_data_schema.core.instrument.logging") as mock_logging:
830+
combined = inst1 + inst2
831+
mock_logging.error.assert_called_once()
832+
error_call_args = mock_logging.error.call_args[0][0]
833+
self.assertIn("CustomClockGenerator", error_call_args)
834+
self.assertIn("duplicated", error_call_args)
835+
836+
self.assertEqual(len(combined.components), 3)
837+
713838

714839
class ConnectionTest(unittest.TestCase):
715840
"""Test the Connection schema"""

0 commit comments

Comments
 (0)