Skip to content

Commit 0f22b64

Browse files
authored
fix: recursively get all component names (#1516)
* fix: repair a bug where get_component_names() wasn't being used to recursively get all component names * tests: fix bad tests that weren't resetting properly
1 parent 30b4b6a commit 0f22b64

File tree

4 files changed

+8
-9
lines changed

4 files changed

+8
-9
lines changed

src/aind_data_schema/core/instrument.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,11 @@ class Instrument(DataCoreModel):
165165
description="List of all devices in the instrument",
166166
)
167167

168-
@classmethod
169-
def get_component_names(cls, instrument: "Instrument") -> List[str]:
168+
def get_component_names(self) -> List[str]:
170169
"""Get the name field of all components, recurse into assemblies."""
171170

172171
names = []
173-
for component in instrument.components:
172+
for component in self.components:
174173
names.extend(recursive_get_all_names(component))
175174
names = [name for name in names if name is not None]
176175

@@ -199,7 +198,7 @@ def validate_cameras_other(self):
199198
@classmethod
200199
def validate_connections(cls, self):
201200
"""validate that all connections map between devices that actually exist"""
202-
device_names = Instrument.get_component_names(self)
201+
device_names = self.get_component_names()
203202

204203
for connection in self.connections:
205204
# Check both source and target devices exist

src/aind_data_schema/core/metadata.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,7 @@ def validate_acquisition_active_devices(cls, values):
203203
device_names = []
204204

205205
if values.instrument:
206-
for component in values.instrument.components:
207-
device_names.append(component.name)
206+
device_names.extend(values.instrument.get_component_names())
208207
if values.procedures:
209208
device_names.extend(values.procedures.get_device_names())
210209

@@ -226,8 +225,7 @@ def validate_acquisition_connections(self):
226225
device_names = []
227226

228227
if self.instrument:
229-
for component in self.instrument.components:
230-
device_names.append(component.name)
228+
device_names.extend(self.instrument.get_component_names())
231229
if self.procedures:
232230
device_names.extend(self.procedures.get_device_names())
233231

src/aind_data_schema/utils/compatibility_check.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _compare_stimulus_devices(self) -> Optional[ValueError]:
3131
for stimulus_epoch in getattr(self.acquisition, "stimulus_epochs", [])
3232
for stimulus_device_name in getattr(stimulus_epoch, "active_devices")
3333
]
34-
instrument_component_names = [getattr(comp, "name", None) for comp in getattr(self.inst, "components", [])]
34+
instrument_component_names = self.inst.get_component_names()
3535

3636
if any(device not in instrument_component_names for device in acquisition_stimulus_devices):
3737
return ValueError(

tests/test_compatibility_check.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def setUp(self):
2626
device1.name = "component_2"
2727

2828
self.mock_instrument.components = [device0, device1]
29+
# Mock the get_component_names method to return the expected component names
30+
self.mock_instrument.get_component_names.return_value = ["component_1", "component_2"]
2931

3032
# Mock acquisition attributes
3133
self.mock_acquisition.instrument_id = "instrument_1"

0 commit comments

Comments
 (0)