Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions src/aind_behavior_vr_foraging/data_mappers/_rig.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,11 @@
class _DeviceNode:
"""Helper class to keep track of devices, their connections and spawned devices"""

device_name: str
device: devices.Device
connections_from: list[connections.Connection] = dataclasses.field(default_factory=list)
spawned_devices: list[devices.Device] = dataclasses.field(default_factory=list)

@property
def device_name(self) -> str:
return self.device.name

def get_spawned_device(self, name: str) -> devices.Device:
for d in self.spawned_devices:
if d.name == name:
Expand Down Expand Up @@ -165,7 +162,7 @@ def _get_calibrations(rig: AindVrForagingRig) -> list[measurements.Calibration]:
@staticmethod
def _get_harp_behavior_node(rig: AindVrForagingRig) -> _DeviceNode:
_connections: list[connections.Connection] = []
source_device = rig.harp_behavior.name or "harp_behavior"
source_device = validate_name(rig, "harp_behavior")

# Add triggered camera controller
if rig.triggered_camera_controller:
Expand Down Expand Up @@ -239,14 +236,15 @@ def _get_harp_behavior_node(rig: AindVrForagingRig) -> _DeviceNode:
)

return _DeviceNode(
device_name=source_device,
device=_harp_device,
connections_from=_connections,
spawned_devices=[speaker, photodiode, water_valve],
)

@staticmethod
def _get_harp_treadmill_node(rig: AindVrForagingRig) -> _DeviceNode:
source_device = rig.harp_treadmill.name or "harp_treadmill"
source_device = validate_name(rig, "harp_treadmill")

_connections = [
connections.Connection(
Expand Down Expand Up @@ -293,12 +291,16 @@ def _get_harp_treadmill_node(rig: AindVrForagingRig) -> _DeviceNode:
)

return _DeviceNode(
device=_harp_device, connections_from=_connections, spawned_devices=[magnetic_brake, encoder, torque_sensor]
device_name=source_device,
device=_harp_device,
connections_from=_connections,
spawned_devices=[magnetic_brake, encoder, torque_sensor],
)

@staticmethod
def _get_harp_clock_generate_node(rig: AindVrForagingRig, components: list[devices.Device]) -> _DeviceNode:
source_device = rig.harp_clock_generator.name or "harp_clock_generator"
source_device = validate_name(rig, "harp_clock_generator")

harp_devices = [d for d in components if isinstance(d, devices.HarpDevice)]
_connections = [
connections.Connection(
Expand All @@ -321,7 +323,7 @@ def _get_harp_clock_generate_node(rig: AindVrForagingRig, components: list[devic
],
)

return _DeviceNode(device=harp_device, connections_from=_connections)
return _DeviceNode(device_name=source_device, device=harp_device, connections_from=_connections)

@staticmethod
def _get_wheel(
Expand Down Expand Up @@ -367,7 +369,7 @@ def _get_all_components_and_connections(

# Get all other harp devices
harp_lickometer = devices.HarpDevice(
name=rig.harp_lickometer.name or "harp_lickometer",
name=validate_name(rig, "harp_lickometer"),
harp_device_type=devices.HarpDeviceType.LICKETYSPLIT,
manufacturer=devices.Organization.AIND,
is_clock_generator=False,
Expand All @@ -376,7 +378,7 @@ def _get_all_components_and_connections(
if rig.harp_sniff_detector is not None:
_components.append(
devices.HarpDevice(
name=rig.harp_sniff_detector.name or "harp_sniff_detector",
name=validate_name(rig, "harp_sniff_detector"),
harp_device_type=devices.HarpDeviceType.SNIFFDETECTOR,
is_clock_generator=False,
)
Expand All @@ -385,7 +387,7 @@ def _get_all_components_and_connections(
if rig.harp_environment_sensor is not None:
_components.append(
devices.HarpDevice(
name=rig.harp_environment_sensor.name or "harp_environment_sensor",
name=validate_name(rig, "harp_environment_sensor"),
harp_device_type=devices.HarpDeviceType.ENVIRONMENTSENSOR,
is_clock_generator=False,
)
Expand Down Expand Up @@ -695,3 +697,10 @@ def _get_olfactometer_channel(
channel_type=ch_type_to_ch_type[ch.channel_type],
flow_capacity=ch.flow_rate_capacity,
)


def validate_name(obj: object, name: str) -> str:
if hasattr(obj, name):
return name
else:
raise ValueError(f"Model {obj.__class__.__name__} does not contain a field {name}.")
25 changes: 19 additions & 6 deletions src/aind_behavior_vr_foraging/data_mappers/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional, Union, cast, get_args

import aind_behavior_services.rig as AbsRig
import git
Expand Down Expand Up @@ -114,23 +114,36 @@ def _get_calibrations(self) -> List[acquisition.CALIBRATIONS]:
calibrations += _get_water_calibration(self.rig_model)
return calibrations

@staticmethod
def _include_device(device: AbsRig.Device) -> bool:
if isinstance(device, AbsRig.visual_stimulation.Screen):
return False
if isinstance(device, AbsRig.cameras.CameraController):
return False
if isinstance(device, get_args(AbsRig.cameras.CameraTypes)):
return cast(AbsRig.cameras.CameraTypes, device).video_writer is not None
return True

def _get_data_streams(self) -> List[acquisition.DataStream]:
assert self.session_end_time is not None, "Session end time is not set."

modalities: list[Modality] = [getattr(Modality, "BEHAVIOR")]
if len(self._get_cameras_config()) > 0:
modalities.append(getattr(Modality, "BEHAVIOR_VIDEOS"))
modalities = list(set(modalities))

active_devices = [
_device[0]
for _device in get_fields_of_type(self.rig_model, AbsRig.Device, stop_recursion_on_type=False)
if _device[0] is not None and self._include_device(_device[1])
]

data_streams: list[acquisition.DataStream] = [
acquisition.DataStream(
stream_start_time=self.session_model.date,
stream_end_time=self.session_end_time,
code=[self._get_bonsai_as_code(), self._get_python_as_code()],
active_devices=[
_device[0]
for _device in get_fields_of_type(self.rig_model, AbsRig.Device, stop_recursion_on_type=False)
if _device[0] is not None
],
active_devices=active_devices,
modalities=modalities,
configurations=self._get_cameras_config(),
notes=self.session_model.notes,
Expand Down
2 changes: 1 addition & 1 deletion src/aind_behavior_vr_foraging/data_mappers/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class TrackedDevices(enum.StrEnum):
TORQUE_SENSOR = "torque_sensor"
ROTARY_ENCODER = "rotary_encoder"
ENCLOSURE = "behavior_enclosure"
MOTORIZED_STAGE = "motorized_stage"
MOTORIZED_STAGE = "manipulator"
LICK_SPOUT = "lick_spout"
SCREEN = "screen"
COMPUTER = "computer"
Expand Down