Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitattributes
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
* text=auto
*.cmd text eol=crlf
*.bat text eol=crlf
*.bonsai text
*.bonsai textuv.lock merge=ours
5 changes: 5 additions & 0 deletions examples/clabe.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,8 @@ default_behavior_picker:
curriculum:
script: "curriculum run"
project_directory: "./src/aind_behavior_vr_foraging/curricula/aind.Behavior.VrForaging.Curricula"

dataverse:
tenant_id: "32669cd6-737f-4b39-8bdd-d6951120d3fc"
client_id: "df37356e-3316-484a-b732-319b6b4ad464"
org: "org5d93e08d"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ data = ["contraqctor<0.5.0"]

launcher = [
"aind-clabe[aind-services]",
"aind-data-schema>2",
"aind-data-schema>=2",
"aind_behavior_vr_foraging[data]",
]

Expand Down
2 changes: 1 addition & 1 deletion src/Extensions/AindBehaviorVrForaging.cs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ private static double ValidateSamples(double[] drawnSamples, double preTruncated

private void ValidateTruncationParameters(TruncationParameters truncationParameters)
{
if (truncationParameters.Min >= truncationParameters.Max)
if (truncationParameters.Min > truncationParameters.Max)
{
throw new ArgumentException("Invalid truncation parameters. Min must be lower than Max");
}
Expand Down
2 changes: 2 additions & 0 deletions src/aind_behavior_vr_foraging/data_contract/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

def _dataset_lookup_helper(version: str) -> t.Callable[[Path], contraqctor.contract.Dataset]:
parsed_version = semver.Version.parse(version)
# Ignore release candidate suffix for version comparison
parsed_version = semver.Version(parsed_version.major, parsed_version.minor, parsed_version.patch)
if semver.Version.parse("0.4.0") <= parsed_version < semver.Version.parse("0.5.0"):
from .v0_4_0 import dataset as _dataset
elif semver.Version.parse("0.5.0") <= parsed_version < semver.Version.parse("0.6.0"):
Expand Down
7 changes: 7 additions & 0 deletions src/aind_behavior_vr_foraging/data_contract/v0_6_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,13 @@ def dataset(
root_path / "behavior/SoftwareEvents/PatchTerminationEvent.json"
),
),
SoftwareEvents(
name="SpoutParkingPositions",
description="Encodes the spout parking positions to use while fully retracted or extended.",
reader_params=SoftwareEvents.make_params(
root_path / "behavior/SoftwareEvents/SpoutParkingPositions.json"
),
),
],
),
DataStreamCollection(
Expand Down
11 changes: 11 additions & 0 deletions src/aind_behavior_vr_foraging/data_mappers/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ def _get_stimulus_epochs(self) -> List[acquisition.StimulusEpoch]:
# logger.error("Olfactometer device not found in rig model.")
# raise ValueError("Olfactometer device not found in rig model.")

if self.curriculum is not None:
performance_metrics = acquisition.PerformanceMetrics(
output_parameters=acquisition.GenericModel.model_validate(self.curriculum.metrics.model_dump())
)
curriculum_status = str(self.curriculum.trainer_state.is_on_curriculum)
else:
curriculum_status = "false"
performance_metrics = None

stimulus_epochs: list[acquisition.StimulusEpoch] = [
acquisition.StimulusEpoch(
active_devices=active_devices,
Expand All @@ -238,6 +247,8 @@ def _get_stimulus_epochs(self) -> List[acquisition.StimulusEpoch]:
configurations=stimulus_epoch_configurations,
stimulus_name=self.session_model.experiment,
stimulus_modalities=stimulus_modalities,
performance_metrics=performance_metrics,
curriculum_status=curriculum_status,
)
]
return stimulus_epochs
Expand Down
2 changes: 1 addition & 1 deletion src/aind_behavior_vr_foraging/data_mappers/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _mapper(
c = water_calibration.output
if c is None:
c = water_calibration.input.calibrate_output()
assert c.interval_average is not None
c.interval_average = c.interval_average or {}

return measurements.VolumeCalibration(
device_name=device_name,
Expand Down
78 changes: 63 additions & 15 deletions src/aind_behavior_vr_foraging/launcher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pathlib import Path
from typing import Any
from typing import Any, cast

from aind_behavior_curriculum import TrainerState
from aind_behavior_services.calibration.aind_manipulator import ManipulatorPosition
from aind_behavior_services.session import AindBehaviorSessionModel
from clabe import resource_monitor
from clabe.apps import (
Expand All @@ -13,15 +14,17 @@
)
from clabe.data_transfer.aind_watchdog import WatchdogDataTransferService, WatchdogSettings
from clabe.launcher import (
DefaultBehaviorPicker,
DefaultBehaviorPickerSettings,
Launcher,
LauncherCliArgs,
Promise,
run_if,
)
from clabe.pickers import DefaultBehaviorPickerSettings
from clabe.pickers.dataverse import DataversePicker
from contraqctor.contract.json import SoftwareEvents
from pydantic_settings import CliApp

from . import data_contract
from .data_mappers import AindRigDataMapper, AindSessionDataMapper
from .data_mappers._utils import write_ads_mappers
from .rig import AindVrForagingRig
Expand All @@ -37,43 +40,47 @@ def make_launcher(settings: LauncherCliArgs) -> Launcher:
bonsai_app = AindBehaviorServicesBonsaiApp(BonsaiAppSettings(workflow=Path(r"./src/main.bonsai")))
trainer = CurriculumApp(settings=CurriculumSettings())
watchdog_settings = WatchdogSettings() # type: ignore[call-arg]
picker = DefaultBehaviorPicker[AindVrForagingRig, AindBehaviorSessionModel, AindVrForagingTaskLogic](
settings=DefaultBehaviorPickerSettings() # type: ignore[call-arg]
picker = DataversePicker[AindVrForagingRig, AindBehaviorSessionModel, AindVrForagingTaskLogic](
settings=DefaultBehaviorPickerSettings()
)
launcher = Launcher(
rig=AindVrForagingRig,
session=AindBehaviorSessionModel,
task_logic=AindVrForagingTaskLogic,
settings=settings,
)
manipulator_modifier = ByAnimalManipulatorModifier(picker)

# Get user input
launcher.register_callable(
[
picker.initialize,
picker.pick_session,
picker.pick_trainer_state,
picker.pick_rig,
]
)
launcher.register_callable(picker.initialize)
launcher.register_callable(picker.pick_session)
launcher.register_callable(picker.pick_rig)
launcher.register_callable(manipulator_modifier.inject)
launcher.register_callable(picker.pick_trainer_state)

# Check resources
launcher.register_callable(monitor.build_runner())

# Run the task via Bonsai
launcher.register_callable(bonsai_app.build_runner(allow_std_error=True))

# Update manipulator initial position for next session
launcher.register_callable(manipulator_modifier.dump)

# Curriculum
suggestion = launcher.register_callable(
run_if(lambda: trainer_state_exists_predicate(picker.trainer_state))(
trainer.build_runner(input_trainer_state=Promise(lambda x: picker.trainer_state))
trainer.build_runner(input_trainer_state=lambda: picker.trainer_state)
)
)
launcher.register_callable(
run_if(lambda: suggestion.result is not None)(lambda launcher: _dump_suggestion(launcher, suggestion))
)

launcher.register_callable(
run_if(lambda: suggestion.result is not None)(lambda launcher: picker.dump_model(launcher, suggestion.result))
run_if(lambda: suggestion.result is not None)(
lambda launcher: picker.push_new_suggestion(launcher, suggestion.result.trainer_state)
)
)

# Mappers
Expand All @@ -97,6 +104,47 @@ def _dump_suggestion(launcher: Launcher[Any, Any, Any], suggestion: Promise[Any,
f.write(suggestion.result.model_dump_json(indent=2))


class ByAnimalManipulatorModifier:
def __init__(self, picker: DataversePicker) -> None:
self._picker = picker

def inject(self, launcher: Launcher[AindVrForagingRig, Any, Any]) -> None:
rig = launcher.get_rig(strict=True)
if launcher.subject is None:
raise ValueError("Launcher subject is not defined!")
target_folder = self._picker.subject_dir / launcher.subject
target_file = target_folder / "manipulator_init.json"
if not target_file.exists():
launcher.logger.warning(f"Manipulator initial position file not found: {target_file}. Using default.")
return
else:
cached = ManipulatorPosition.model_validate_json(target_file.read_text(encoding="utf-8"))
launcher.logger.info(f"Loading manipulator initial position from: {target_file}. Deserialized: {cached}")
assert rig.manipulator.calibration is not None
rig.manipulator.calibration.input.initial_position = cached
launcher.set_rig(rig, force=True)
return

def dump(self, launcher: Launcher[AindVrForagingRig, Any, Any]) -> None:
assert launcher.subject is not None
target_folder = self._picker.subject_dir / launcher.subject
target_file = target_folder / "manipulator_init.json"
_dataset = data_contract.dataset(launcher.session_directory)
try:
manipulator_parking_position: SoftwareEvents = cast(
SoftwareEvents, _dataset["Behavior"]["SoftwareEvents"]["SpoutParkingPositions"].load()
)
data: dict[str, Any] = manipulator_parking_position.data.iloc[0]["data"]["ResetPosition"]
position = ManipulatorPosition.model_validate(data)
except Exception as e:
launcher.logger.error(f"Failed to load manipulator parking position: {e}")
return
else:
launcher.logger.info(f"Saving manipulator initial position to: {target_file}. Serialized: {position}")
target_folder.mkdir(parents=True, exist_ok=True)
target_file.write_text(position.model_dump_json(indent=2), encoding="utf-8")


def trainer_state_exists_predicate(input_trainer_state: TrainerState | Promise[Any, TrainerState]) -> bool:
if isinstance(input_trainer_state, Promise):
input_trainer_state = input_trainer_state.result
Expand Down
29 changes: 21 additions & 8 deletions src/main.bonsai
Original file line number Diff line number Diff line change
Expand Up @@ -3364,6 +3364,17 @@ Item3 as Right)</scr:Expression>
<Expression xsi:type="rx:AsyncSubject">
<Name>SpoutParkingPositions</Name>
</Expression>
<Expression xsi:type="SubscribeSubject">
<Name>SpoutParkingPositions</Name>
</Expression>
<Expression xsi:type="Combinator">
<Combinator xsi:type="p5:CreateSoftwareEvent">
<p5:EventName>SpoutParkingPositions</p5:EventName>
</Combinator>
</Expression>
<Expression xsi:type="MulticastSubject">
<Name>SoftwareEvent</Name>
</Expression>
<Expression xsi:type="SubscribeSubject">
<Name>TaskLogicParameters</Name>
</Expression>
Expand Down Expand Up @@ -3535,22 +3546,24 @@ Item3 as Right)</scr:Expression>
<Edge From="8" To="9" Label="Source1" />
<Edge From="10" To="11" Label="Source1" />
<Edge From="11" To="12" Label="Source1" />
<Edge From="12" To="13" Label="Source1" />
<Edge From="13" To="14" Label="Source1" />
<Edge From="14" To="15" Label="Source1" />
<Edge From="15" To="16" Label="Source1" />
<Edge From="16" To="20" Label="Source1" />
<Edge From="17" To="18" Label="Source1" />
<Edge From="16" To="17" Label="Source1" />
<Edge From="18" To="19" Label="Source1" />
<Edge From="19" To="20" Label="Source2" />
<Edge From="19" To="23" Label="Source1" />
<Edge From="20" To="21" Label="Source1" />
<Edge From="21" To="22" Label="Source1" />
<Edge From="22" To="23" Label="Source1" />
<Edge From="22" To="23" Label="Source2" />
<Edge From="23" To="24" Label="Source1" />
<Edge From="24" To="29" Label="Source1" />
<Edge From="24" To="25" Label="Source1" />
<Edge From="25" To="26" Label="Source1" />
<Edge From="26" To="27" Label="Source1" />
<Edge From="27" To="28" Label="Source1" />
<Edge From="28" To="29" Label="Source2" />
<Edge From="27" To="32" Label="Source1" />
<Edge From="28" To="29" Label="Source1" />
<Edge From="29" To="30" Label="Source1" />
<Edge From="30" To="31" Label="Source1" />
<Edge From="31" To="32" Label="Source2" />
</Edges>
</Workflow>
</Expression>
Expand Down
Loading
Loading