Skip to content

Commit 189fcc4

Browse files
authored
Fix issue with device's names not being aligned across instrument and acquisition mappers (#453)
* Ensure names are aligned between mappers * Fix double import
1 parent 7766fe1 commit 189fcc4

File tree

3 files changed

+41
-19
lines changed

3 files changed

+41
-19
lines changed

src/aind_behavior_vr_foraging/data_mappers/_rig.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,11 @@
2626
class _DeviceNode:
2727
"""Helper class to keep track of devices, their connections and spawned devices"""
2828

29+
device_name: str
2930
device: devices.Device
3031
connections_from: list[connections.Connection] = dataclasses.field(default_factory=list)
3132
spawned_devices: list[devices.Device] = dataclasses.field(default_factory=list)
3233

33-
@property
34-
def device_name(self) -> str:
35-
return self.device.name
36-
3734
def get_spawned_device(self, name: str) -> devices.Device:
3835
for d in self.spawned_devices:
3936
if d.name == name:
@@ -165,7 +162,7 @@ def _get_calibrations(rig: AindVrForagingRig) -> list[measurements.Calibration]:
165162
@staticmethod
166163
def _get_harp_behavior_node(rig: AindVrForagingRig) -> _DeviceNode:
167164
_connections: list[connections.Connection] = []
168-
source_device = rig.harp_behavior.name or "harp_behavior"
165+
source_device = validate_name(rig, "harp_behavior")
169166

170167
# Add triggered camera controller
171168
if rig.triggered_camera_controller:
@@ -239,14 +236,15 @@ def _get_harp_behavior_node(rig: AindVrForagingRig) -> _DeviceNode:
239236
)
240237

241238
return _DeviceNode(
239+
device_name=source_device,
242240
device=_harp_device,
243241
connections_from=_connections,
244242
spawned_devices=[speaker, photodiode, water_valve],
245243
)
246244

247245
@staticmethod
248246
def _get_harp_treadmill_node(rig: AindVrForagingRig) -> _DeviceNode:
249-
source_device = rig.harp_treadmill.name or "harp_treadmill"
247+
source_device = validate_name(rig, "harp_treadmill")
250248

251249
_connections = [
252250
connections.Connection(
@@ -293,12 +291,16 @@ def _get_harp_treadmill_node(rig: AindVrForagingRig) -> _DeviceNode:
293291
)
294292

295293
return _DeviceNode(
296-
device=_harp_device, connections_from=_connections, spawned_devices=[magnetic_brake, encoder, torque_sensor]
294+
device_name=source_device,
295+
device=_harp_device,
296+
connections_from=_connections,
297+
spawned_devices=[magnetic_brake, encoder, torque_sensor],
297298
)
298299

299300
@staticmethod
300301
def _get_harp_clock_generate_node(rig: AindVrForagingRig, components: list[devices.Device]) -> _DeviceNode:
301-
source_device = rig.harp_clock_generator.name or "harp_clock_generator"
302+
source_device = validate_name(rig, "harp_clock_generator")
303+
302304
harp_devices = [d for d in components if isinstance(d, devices.HarpDevice)]
303305
_connections = [
304306
connections.Connection(
@@ -321,7 +323,7 @@ def _get_harp_clock_generate_node(rig: AindVrForagingRig, components: list[devic
321323
],
322324
)
323325

324-
return _DeviceNode(device=harp_device, connections_from=_connections)
326+
return _DeviceNode(device_name=source_device, device=harp_device, connections_from=_connections)
325327

326328
@staticmethod
327329
def _get_wheel(
@@ -367,7 +369,7 @@ def _get_all_components_and_connections(
367369

368370
# Get all other harp devices
369371
harp_lickometer = devices.HarpDevice(
370-
name=rig.harp_lickometer.name or "harp_lickometer",
372+
name=validate_name(rig, "harp_lickometer"),
371373
harp_device_type=devices.HarpDeviceType.LICKETYSPLIT,
372374
manufacturer=devices.Organization.AIND,
373375
is_clock_generator=False,
@@ -376,7 +378,7 @@ def _get_all_components_and_connections(
376378
if rig.harp_sniff_detector is not None:
377379
_components.append(
378380
devices.HarpDevice(
379-
name=rig.harp_sniff_detector.name or "harp_sniff_detector",
381+
name=validate_name(rig, "harp_sniff_detector"),
380382
harp_device_type=devices.HarpDeviceType.SNIFFDETECTOR,
381383
is_clock_generator=False,
382384
)
@@ -385,7 +387,7 @@ def _get_all_components_and_connections(
385387
if rig.harp_environment_sensor is not None:
386388
_components.append(
387389
devices.HarpDevice(
388-
name=rig.harp_environment_sensor.name or "harp_environment_sensor",
390+
name=validate_name(rig, "harp_environment_sensor"),
389391
harp_device_type=devices.HarpDeviceType.ENVIRONMENTSENSOR,
390392
is_clock_generator=False,
391393
)
@@ -695,3 +697,10 @@ def _get_olfactometer_channel(
695697
channel_type=ch_type_to_ch_type[ch.channel_type],
696698
flow_capacity=ch.flow_rate_capacity,
697699
)
700+
701+
702+
def validate_name(obj: object, name: str) -> str:
703+
if hasattr(obj, name):
704+
return name
705+
else:
706+
raise ValueError(f"Model {obj.__class__.__name__} does not contain a field {name}.")

src/aind_behavior_vr_foraging/data_mappers/_session.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import sys
55
from pathlib import Path
6-
from typing import List, Optional, Union
6+
from typing import List, Optional, Union, cast, get_args
77

88
import aind_behavior_services.rig as AbsRig
99
import git
@@ -114,23 +114,36 @@ def _get_calibrations(self) -> List[acquisition.CALIBRATIONS]:
114114
calibrations += _get_water_calibration(self.rig_model)
115115
return calibrations
116116

117+
@staticmethod
118+
def _include_device(device: AbsRig.Device) -> bool:
119+
if isinstance(device, AbsRig.visual_stimulation.Screen):
120+
return False
121+
if isinstance(device, AbsRig.cameras.CameraController):
122+
return False
123+
if isinstance(device, get_args(AbsRig.cameras.CameraTypes)):
124+
return cast(AbsRig.cameras.CameraTypes, device).video_writer is not None
125+
return True
126+
117127
def _get_data_streams(self) -> List[acquisition.DataStream]:
118128
assert self.session_end_time is not None, "Session end time is not set."
119129

120130
modalities: list[Modality] = [getattr(Modality, "BEHAVIOR")]
121131
if len(self._get_cameras_config()) > 0:
122132
modalities.append(getattr(Modality, "BEHAVIOR_VIDEOS"))
123133
modalities = list(set(modalities))
134+
135+
active_devices = [
136+
_device[0]
137+
for _device in get_fields_of_type(self.rig_model, AbsRig.Device, stop_recursion_on_type=False)
138+
if _device[0] is not None and self._include_device(_device[1])
139+
]
140+
124141
data_streams: list[acquisition.DataStream] = [
125142
acquisition.DataStream(
126143
stream_start_time=self.session_model.date,
127144
stream_end_time=self.session_end_time,
128145
code=[self._get_bonsai_as_code(), self._get_python_as_code()],
129-
active_devices=[
130-
_device[0]
131-
for _device in get_fields_of_type(self.rig_model, AbsRig.Device, stop_recursion_on_type=False)
132-
if _device[0] is not None
133-
],
146+
active_devices=active_devices,
134147
modalities=modalities,
135148
configurations=self._get_cameras_config(),
136149
notes=self.session_model.notes,

src/aind_behavior_vr_foraging/data_mappers/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class TrackedDevices(enum.StrEnum):
106106
TORQUE_SENSOR = "torque_sensor"
107107
ROTARY_ENCODER = "rotary_encoder"
108108
ENCLOSURE = "behavior_enclosure"
109-
MOTORIZED_STAGE = "motorized_stage"
109+
MOTORIZED_STAGE = "manipulator"
110110
LICK_SPOUT = "lick_spout"
111111
SCREEN = "screen"
112112
COMPUTER = "computer"

0 commit comments

Comments
 (0)