diff --git a/examples/task_mcm.py b/examples/task_mcm.py index 820ed5ce..9de7adb5 100644 --- a/examples/task_mcm.py +++ b/examples/task_mcm.py @@ -55,7 +55,7 @@ def compute_cmc_transition_probability(n_states, rep_rate, T=3.5, dt=0.1) -> np. operation_control = vr_task_logic.OperationControl( movable_spout_control=vr_task_logic.MovableSpoutControl(enabled=False), - audio_control=vr_task_logic.AudioControl(duration=0.2, frequency=5000), + audio_control=vr_task_logic.AudioControl(duration=0.2, frequency=9999), odor_control=vr_task_logic.OdorControl(), position_control=vr_task_logic.PositionControl( frequency_filter_cutoff=5, diff --git a/examples/task_patch_foraging.py b/examples/task_patch_foraging.py index 5a71a54f..70582382 100644 --- a/examples/task_patch_foraging.py +++ b/examples/task_patch_foraging.py @@ -108,7 +108,7 @@ def PostPatchVirtualSiteGeneratorHelper(contrast: float = 1, friction: float = 0 rule=vr_task_logic.RewardFunctionRule.ON_REWARD, ) -reset_function = vr_task_logic.OnThisPatchEntryFunction( +reset_function = vr_task_logic.OnThisPatchEntryRewardFunction( available=vr_task_logic.SetValueFunction(value=vr_task_logic.scalar_value(0.1)) ) diff --git a/examples/test_single_site_patch.py b/examples/test_single_site_patch.py new file mode 100644 index 00000000..032d9751 --- /dev/null +++ b/examples/test_single_site_patch.py @@ -0,0 +1,186 @@ +import os +from typing import Optional + +import aind_behavior_services.task_logic.distributions as distributions +from aind_behavior_curriculum import Stage, TrainerState + +import aind_behavior_vr_foraging.task_logic as vr_task_logic +from aind_behavior_vr_foraging.task_logic import ( + AindVrForagingTaskLogic, + AindVrForagingTaskParameters, +) + +MINIMUM_INTERPATCH_LENGTH = 50 +MEAN_INTERPATCH_LENGTH = 150 +MAXIMUM_INTERPATCH_LENGTH = 500 +INTERSITE_LENGTH = 50 +REWARDSITE_LENGTH = 50 +REWARD_AMOUNT = 3 +VELOCITY_THRESHOLD = 15 # cm/s + +P_REWARD_BLOCK: list[tuple[float, Optional[float], Optional[float]]] = [ + (1.0, 1.0, None), + (0.8, 0.8, None), + (0.8, 0.2, None), +] + +P_BAIT_BLOCK = [ + (1.0, 1.0, None), + (0.4, 0.4, None), + (0.4, 0.1, None), +] + + +def make_patch( + label: str, + state_index: int, + odor_index: int, + p_reward: float, + p_replenish: float, +): + baiting_function = vr_task_logic.PersistentRewardFunction( + rule=vr_task_logic.RewardFunctionRule.ON_PATCH_ENTRY, + probability=vr_task_logic.SetValueFunction( + value=distributions.BinomialDistribution( + distribution_parameters=distributions.BinomialDistributionParameters(n=1, p=p_replenish), + scaling_parameters=distributions.ScalingParameters(offset=p_reward), + truncation_parameters=distributions.TruncationParameters(min=p_reward, max=1), + ), + ), + ) + + depletion_function = vr_task_logic.PatchRewardFunction( + probability=vr_task_logic.SetValueFunction( + value=vr_task_logic.scalar_value(p_reward), + ), + rule=vr_task_logic.RewardFunctionRule.ON_REWARD, + ) + + return vr_task_logic.Patch( + label=label, + state_index=state_index, + odor_specification=vr_task_logic.OdorSpecification(index=odor_index, concentration=1), + patch_terminators=[ + vr_task_logic.PatchTerminatorOnChoice(count=vr_task_logic.scalar_value(1)), + vr_task_logic.PatchTerminatorOnRejection(count=vr_task_logic.scalar_value(1)), + ], + reward_specification=vr_task_logic.RewardSpecification( + amount=vr_task_logic.scalar_value(REWARD_AMOUNT), + probability=vr_task_logic.scalar_value(p_reward), + available=vr_task_logic.scalar_value(999999), + delay=vr_task_logic.scalar_value(0.5), + operant_logic=vr_task_logic.OperantLogic( + is_operant=False, + stop_duration=0.5, + time_to_collect_reward=100000, + grace_distance_threshold=10, + ), + reward_function=[baiting_function, depletion_function], + ), + patch_virtual_sites_generator=vr_task_logic.PatchVirtualSitesGenerator( + inter_patch=vr_task_logic.VirtualSiteGenerator( + render_specification=vr_task_logic.RenderSpecification(contrast=1), + label=vr_task_logic.VirtualSiteLabels.INTERPATCH, + length_distribution=distributions.ExponentialDistribution( + distribution_parameters=distributions.ExponentialDistributionParameters( + rate=1 / MEAN_INTERPATCH_LENGTH + ), + scaling_parameters=distributions.ScalingParameters(offset=MINIMUM_INTERPATCH_LENGTH), + truncation_parameters=distributions.TruncationParameters( + min=MINIMUM_INTERPATCH_LENGTH, + max=MAXIMUM_INTERPATCH_LENGTH, + ), + ), + ), + inter_site=vr_task_logic.VirtualSiteGenerator( + render_specification=vr_task_logic.RenderSpecification(contrast=0.5), + label=vr_task_logic.VirtualSiteLabels.INTERSITE, + length_distribution=vr_task_logic.scalar_value(INTERSITE_LENGTH), + ), + reward_site=vr_task_logic.VirtualSiteGenerator( + render_specification=vr_task_logic.RenderSpecification(contrast=0.5), + label=vr_task_logic.VirtualSiteLabels.REWARDSITE, + length_distribution=vr_task_logic.scalar_value(REWARDSITE_LENGTH), + ), + ), + ) + + +def make_block( + p_rew: tuple[float, Optional[float], Optional[float]], + p_replenish: tuple[float, Optional[float], Optional[float]], + n_min_trials: int = 100, +) -> vr_task_logic.Block: + patches = [make_patch(label="OdorA", state_index=0, odor_index=0, p_reward=p_rew[0], p_replenish=p_replenish[0])] + if p_rew[1] is not None: + assert p_replenish[1] is not None + patches.append( + make_patch(label="OdorB", state_index=1, odor_index=1, p_reward=p_rew[1], p_replenish=p_replenish[1]) + ) + if p_rew[2] is not None: + assert p_replenish[2] is not None + patches.append( + make_patch(label="OdorC", state_index=2, odor_index=2, p_reward=p_rew[2], p_replenish=p_replenish[2]) + ) + + per_p = 1.0 / len(patches) + return vr_task_logic.Block( + environment_statistics=vr_task_logic.EnvironmentStatistics( + first_state_occupancy=[per_p] * len(patches), + transition_matrix=[[per_p] * len(patches) for _ in range(len(patches))], + patches=patches, + ), + end_conditions=[ + vr_task_logic.BlockEndConditionPatchCount( + value=distributions.ExponentialDistribution( + distribution_parameters=distributions.ExponentialDistributionParameters(rate=1 / 25), + scaling_parameters=distributions.ScalingParameters(offset=n_min_trials), + truncation_parameters=distributions.TruncationParameters(min=n_min_trials, max=n_min_trials + 50), + ) + ) + ], + ) + + +operation_control = vr_task_logic.OperationControl( + movable_spout_control=vr_task_logic.MovableSpoutControl(enabled=False), + audio_control=vr_task_logic.AudioControl(duration=0.2, frequency=9999), + odor_control=vr_task_logic.OdorControl(), + position_control=vr_task_logic.PositionControl( + frequency_filter_cutoff=5, + velocity_threshold=VELOCITY_THRESHOLD, + ), +) + + +task_logic = AindVrForagingTaskLogic( + task_parameters=AindVrForagingTaskParameters( + rng_seed=None, + environment=vr_task_logic.BlockStructure( + blocks=[ + make_block(p_rew=P_REWARD_BLOCK[i], p_replenish=P_BAIT_BLOCK[i], n_min_trials=100) + for i in range(len(P_REWARD_BLOCK)) + ], + sampling_mode="Sequential", + ), + operation_control=operation_control, + ), + stage_name="single_site_patch", +) + + +def main(path_seed: str = "./local/SingleSitePatch_{schema}.json"): + example_task_logic = task_logic + example_trainer_state = TrainerState( + stage=Stage(name="example_stage", task=example_task_logic), curriculum=None, is_on_curriculum=False + ) + os.makedirs(os.path.dirname(path_seed), exist_ok=True) + models = [example_task_logic, example_trainer_state] + + for model in models: + with open(path_seed.format(schema=model.__class__.__name__), "w", encoding="utf-8") as f: + f.write(model.model_dump_json(indent=2)) + + +if __name__ == "__main__": + main() diff --git a/src/DataSchemas/aind_behavior_vr_foraging.json b/src/DataSchemas/aind_behavior_vr_foraging.json index fecb3515..680eb270 100644 --- a/src/DataSchemas/aind_behavior_vr_foraging.json +++ b/src/DataSchemas/aind_behavior_vr_foraging.json @@ -2927,12 +2927,12 @@ "title": "OlfactometerChannelType", "type": "string" }, - "OnThisPatchEntryFunction": { + "OnThisPatchEntryRewardFunction": { "description": "A RewardFunction that is applied when the animal enters the patch.", "properties": { "function_type": { - "const": "OnThisPatchEntryFunction", - "default": "OnThisPatchEntryFunction", + "const": "OnThisPatchEntryRewardFunction", + "default": "OnThisPatchEntryRewardFunction", "title": "Function Type", "type": "string" }, @@ -2980,7 +2980,7 @@ "type": "string" } }, - "title": "OnThisPatchEntryFunction", + "title": "OnThisPatchEntryRewardFunction", "type": "object" }, "OperantLogic": { @@ -3694,6 +3694,73 @@ "title": "PdfDistributionParameters", "type": "object" }, + "PersistentRewardFunction": { + "description": "A RewardFunction that is always active.", + "properties": { + "function_type": { + "const": "PersistentRewardFunction", + "default": "PersistentRewardFunction", + "title": "Function Type", + "type": "string" + }, + "amount": { + "default": null, + "description": "Defines the amount of reward replenished per rule unit.", + "oneOf": [ + { + "$ref": "#/$defs/PatchUpdateFunction" + }, + { + "type": "null" + } + ] + }, + "probability": { + "default": null, + "description": "Defines the probability of reward replenished per rule unit.", + "oneOf": [ + { + "$ref": "#/$defs/PatchUpdateFunction" + }, + { + "type": "null" + } + ] + }, + "available": { + "default": null, + "description": "Defines the amount of reward available replenished in the patch per rule unit.", + "oneOf": [ + { + "$ref": "#/$defs/PatchUpdateFunction" + }, + { + "type": "null" + } + ] + }, + "rule": { + "enum": [ + "OnReward", + "OnChoice", + "OnTime", + "OnDistance", + "OnChoiceAccumulated", + "OnRewardAccumulated", + "OnTimeAccumulated", + "OnDistanceAccumulated", + "OnPatchEntry" + ], + "title": "Rule", + "type": "string" + } + }, + "required": [ + "rule" + ], + "title": "PersistentRewardFunction", + "type": "object" + }, "PoissonDistribution": { "properties": { "family": { @@ -3856,9 +3923,10 @@ "RewardFunction": { "discriminator": { "mapping": { - "OnThisPatchEntryFunction": "#/$defs/OnThisPatchEntryFunction", + "OnThisPatchEntryRewardFunction": "#/$defs/OnThisPatchEntryRewardFunction", "OutsideRewardFunction": "#/$defs/OutsideRewardFunction", - "PatchRewardFunction": "#/$defs/PatchRewardFunction" + "PatchRewardFunction": "#/$defs/PatchRewardFunction", + "PersistentRewardFunction": "#/$defs/PersistentRewardFunction" }, "propertyName": "function_type" }, @@ -3870,7 +3938,10 @@ "$ref": "#/$defs/OutsideRewardFunction" }, { - "$ref": "#/$defs/OnThisPatchEntryFunction" + "$ref": "#/$defs/OnThisPatchEntryRewardFunction" + }, + { + "$ref": "#/$defs/PersistentRewardFunction" } ] }, diff --git a/src/Extensions/AindBehaviorVrForaging.Generated.cs b/src/Extensions/AindBehaviorVrForaging.Generated.cs index baa1a876..b92a0cc4 100644 --- a/src/Extensions/AindBehaviorVrForaging.Generated.cs +++ b/src/Extensions/AindBehaviorVrForaging.Generated.cs @@ -8104,7 +8104,7 @@ public enum OlfactometerChannelType [System.ComponentModel.DescriptionAttribute("A RewardFunction that is applied when the animal enters the patch.")] [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] [Bonsai.CombinatorAttribute(MethodName="Generate")] - public partial class OnThisPatchEntryFunction : RewardFunction + public partial class OnThisPatchEntryRewardFunction : RewardFunction { private PatchUpdateFunction _amount; @@ -8115,12 +8115,12 @@ public partial class OnThisPatchEntryFunction : RewardFunction private string _rule; - public OnThisPatchEntryFunction() + public OnThisPatchEntryRewardFunction() { _rule = "OnThisPatchEntry"; } - protected OnThisPatchEntryFunction(OnThisPatchEntryFunction other) : + protected OnThisPatchEntryRewardFunction(OnThisPatchEntryRewardFunction other) : base(other) { _amount = other._amount; @@ -8200,14 +8200,14 @@ public string Rule } } - public System.IObservable Generate() + public System.IObservable Generate() { - return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new OnThisPatchEntryFunction(this))); + return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new OnThisPatchEntryRewardFunction(this))); } - public System.IObservable Generate(System.IObservable source) + public System.IObservable Generate(System.IObservable source) { - return System.Reactive.Linq.Observable.Select(source, _ => new OnThisPatchEntryFunction(this)); + return System.Reactive.Linq.Observable.Select(source, _ => new OnThisPatchEntryRewardFunction(this)); } protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) @@ -9771,6 +9771,130 @@ public override string ToString() } + /// + /// A RewardFunction that is always active. + /// + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.6.1.0 (Newtonsoft.Json v13.0.0.0)")] + [System.ComponentModel.DescriptionAttribute("A RewardFunction that is always active.")] + [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] + [Bonsai.CombinatorAttribute(MethodName="Generate")] + public partial class PersistentRewardFunction : RewardFunction + { + + private PatchUpdateFunction _amount; + + private PatchUpdateFunction _probability; + + private PatchUpdateFunction _available; + + private PersistentRewardFunctionRule _rule; + + public PersistentRewardFunction() + { + } + + protected PersistentRewardFunction(PersistentRewardFunction other) : + base(other) + { + _amount = other._amount; + _probability = other._probability; + _available = other._available; + _rule = other._rule; + } + + /// + /// Defines the amount of reward replenished per rule unit. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("amount")] + [System.ComponentModel.DescriptionAttribute("Defines the amount of reward replenished per rule unit.")] + public PatchUpdateFunction Amount + { + get + { + return _amount; + } + set + { + _amount = value; + } + } + + /// + /// Defines the probability of reward replenished per rule unit. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("probability")] + [System.ComponentModel.DescriptionAttribute("Defines the probability of reward replenished per rule unit.")] + public PatchUpdateFunction Probability + { + get + { + return _probability; + } + set + { + _probability = value; + } + } + + /// + /// Defines the amount of reward available replenished in the patch per rule unit. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("available")] + [System.ComponentModel.DescriptionAttribute("Defines the amount of reward available replenished in the patch per rule unit.")] + public PatchUpdateFunction Available + { + get + { + return _available; + } + set + { + _available = value; + } + } + + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("rule", Required=Newtonsoft.Json.Required.Always)] + public PersistentRewardFunctionRule Rule + { + get + { + return _rule; + } + set + { + _rule = value; + } + } + + public System.IObservable Generate() + { + return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new PersistentRewardFunction(this))); + } + + public System.IObservable Generate(System.IObservable source) + { + return System.Reactive.Linq.Observable.Select(source, _ => new PersistentRewardFunction(this)); + } + + protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) + { + if (base.PrintMembers(stringBuilder)) + { + stringBuilder.Append(", "); + } + stringBuilder.Append("Amount = " + _amount + ", "); + stringBuilder.Append("Probability = " + _probability + ", "); + stringBuilder.Append("Available = " + _available + ", "); + stringBuilder.Append("Rule = " + _rule); + return true; + } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.6.1.0 (Newtonsoft.Json v13.0.0.0)")] [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] [Bonsai.CombinatorAttribute(MethodName="Generate")] @@ -10316,8 +10440,9 @@ public override string ToString() [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.6.1.0 (Newtonsoft.Json v13.0.0.0)")] [Newtonsoft.Json.JsonConverter(typeof(JsonInheritanceConverter), "function_type")] [JsonInheritanceAttribute("PatchRewardFunction", typeof(PatchRewardFunction))] - [JsonInheritanceAttribute("OnThisPatchEntryFunction", typeof(OnThisPatchEntryFunction))] + [JsonInheritanceAttribute("OnThisPatchEntryRewardFunction", typeof(OnThisPatchEntryRewardFunction))] [JsonInheritanceAttribute("OutsideRewardFunction", typeof(OutsideRewardFunction))] + [JsonInheritanceAttribute("PersistentRewardFunction", typeof(PersistentRewardFunction))] [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] [Bonsai.CombinatorAttribute(MethodName="Generate")] public partial class RewardFunction @@ -15357,6 +15482,40 @@ public enum PatchRewardFunctionRule } + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.6.1.0 (Newtonsoft.Json v13.0.0.0)")] + [Newtonsoft.Json.JsonConverter(typeof(Newtonsoft.Json.Converters.StringEnumConverter))] + public enum PersistentRewardFunctionRule + { + + [System.Runtime.Serialization.EnumMemberAttribute(Value="OnReward")] + OnReward = 0, + + [System.Runtime.Serialization.EnumMemberAttribute(Value="OnChoice")] + OnChoice = 1, + + [System.Runtime.Serialization.EnumMemberAttribute(Value="OnTime")] + OnTime = 2, + + [System.Runtime.Serialization.EnumMemberAttribute(Value="OnDistance")] + OnDistance = 3, + + [System.Runtime.Serialization.EnumMemberAttribute(Value="OnChoiceAccumulated")] + OnChoiceAccumulated = 4, + + [System.Runtime.Serialization.EnumMemberAttribute(Value="OnRewardAccumulated")] + OnRewardAccumulated = 5, + + [System.Runtime.Serialization.EnumMemberAttribute(Value="OnTimeAccumulated")] + OnTimeAccumulated = 6, + + [System.Runtime.Serialization.EnumMemberAttribute(Value="OnDistanceAccumulated")] + OnDistanceAccumulated = 7, + + [System.Runtime.Serialization.EnumMemberAttribute(Value="OnPatchEntry")] + OnPatchEntry = 8, + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.6.1.0 (Newtonsoft.Json v13.0.0.0)")] [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] [Bonsai.CombinatorAttribute(MethodName="Generate")] @@ -15780,8 +15939,9 @@ private static System.IObservable Process(System.IObservable

))] - [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] public partial class MatchRewardFunction : Bonsai.Expressions.SingleArgumentExpressionBuilder { @@ -16465,9 +16625,9 @@ public System.IObservable Process(System.IObservable(source); } - public System.IObservable Process(System.IObservable source) + public System.IObservable Process(System.IObservable source) { - return Process(source); + return Process(source); } public System.IObservable Process(System.IObservable source) @@ -16545,6 +16705,11 @@ public System.IObservable Process(System.IObservable(source); } + public System.IObservable Process(System.IObservable source) + { + return Process(source); + } + public System.IObservable Process(System.IObservable source) { return Process(source); @@ -16847,7 +17012,7 @@ public System.IObservable Process(System.IObservable [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] - [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] @@ -16863,6 +17028,7 @@ public System.IObservable Process(System.IObservable [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] diff --git a/src/Extensions/PatchManagement.bonsai b/src/Extensions/PatchManagement.bonsai index cf826477..61b2aa1f 100644 --- a/src/Extensions/PatchManagement.bonsai +++ b/src/Extensions/PatchManagement.bonsai @@ -1830,7 +1830,7 @@ - + OnThisPatchEntryFunction @@ -1957,6 +1957,1020 @@ + + Persistent + + + + + true + + + + + 1 + + + + IsActive + + + IsActive + + + + + + + + + Source1 + + + + + + + + + + GenerateTick + + + + ThisPatch + + + + 1 + + + + Value + + + RewardSpecification + + + RewardFunction + + + + + + + + + OnChoice + + + + Source1 + + + Rule + + + + + + + OnChoice + + + + + + + + + + + + + + OnChoice + + + + Source1 + + + + 1 + + + + RewardFunction + + + HotChoiceFeedback + + + + 1 + + + + RewardFunction + + + + + + + + + + + + + + + + + + OnReward + + + + Source1 + + + Rule + + + + + + + OnReward + + + + + + + + + + + + + + OnReward + + + + Source1 + + + + 1 + + + + RewardFunction + + + HotGiveReward + + + HasValue + + + + Source1 + + + HasValue + + + + + + + + + + + + 1 + + + + RewardFunction + + + + + + + + + + + + + + + + + + + OnTime + + + + Source1 + + + Rule + + + + + + + OnTime + + + + + + + + + + + + + + OnTime + + + + Source1 + + + + 1 + + + + RewardFunction + + + + + + TimeStep + + + ElapsedRealTime + + + + RefreshEnvironmentRate + + + + + + + + + PT0S + + + + + + + + 1 + + + + RewardFunction + + + + + + + + + + + + + + + + + + + + + + + + + OnDistance + + + + Source1 + + + Rule + + + + + + + OnDistance + + + + + + + + + + + + + + OnDistance + + + + Source1 + + + + 1 + + + + RewardFunction + + + CurrentPosition + + + Value.Z + + + RefreshEnvironmentRate + + + + + + + + + PT0S + PT1S + + + + + + + + 1 + + + + + + + RewardFunction + + + + + + + + + + + + + + + + + + + + + + + + OnChoiceAccumulated + + + + Source1 + + + Rule + + + + + + + OnChoiceAccumulated + + + + + + + + + + + + + + OnChoiceAccumulated + + + + Source1 + + + + 1 + + + + RewardFunction + + + HotChoiceFeedback + + + + 1 + + + + + RewardFunction + + + + + + + + + + + + + + + + + + + OnRewardAccumulated + + + + Source1 + + + Rule + + + + + + + OnRewardAccumulated + + + + + + + + + + + + + + OnRewardAccumulated + + + + Source1 + + + + 1 + + + + RewardFunction + + + HotGiveReward + + + HasValue + + + + Source1 + + + HasValue + + + + + + + + + + + + 1 + + + + + RewardFunction + + + + + + + + + + + + + + + + + + + + OnTimeAccumulated + + + + Source1 + + + Rule + + + + + + + OnTimeAccumulated + + + + + + + + + + + + + + OnTimeAccumulated + + + + Source1 + + + + 1 + + + + RewardFunction + + + + + + TimeStep + + + ElapsedRealTime + + + + RefreshEnvironmentRate + + + + + + + + + PT0S + + + + + + + RewardFunction + + + + + + + + + + + + + + + + + + + + + + + + OnDistanceAccumulated + + + + Source1 + + + Rule + + + + + + + OnDistanceAccumulated + + + + + + + + + + + + + + OnDistanceAccumulated + + + + Source1 + + + + 1 + + + + RewardFunction + + + CurrentPosition + + + Value.Z + + + + 1 + + + + InitialPosition + + + CurrentPosition + + + Value.Z + + + RefreshEnvironmentRate + + + + + + + + + PT0S + PT1S + + + + + + + InitialPosition + + + + + + + toDouble + double(it) + + + RewardFunction + + + + + + + + + + + + + + + + + + + + + + + + + + + + + OnPatchEntry + + + + Source1 + + + Rule + + + + + + + OnPatchEntry + + + + + + + + + + + + + + OnPatchEntry + + + + Source1 + + + + 1 + + + + RewardFunction + + + CurrentPosition + + + Value.Z + + + + 1 + + + + InitialPosition + + + CurrentPosition + + + Value.Z + + + RefreshEnvironmentRate + + + + + + + + + PT0S + PT1S + + + + + + + InitialPosition + + + + + + + toDouble + double(it) + + + RewardFunction + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + IsActive + + + + + + + Source1 + + + + + + + + + + + + + ThisPatch + + + Value.StateIndex + + + + + + Item1.Item1,Item2,Item1.Item2.Amount,Item1.Item2.Probability,Item1.Item2.Available + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -1967,10 +2981,11 @@ - - - - + + + + + diff --git a/src/Extensions/SolveBlockTransitions.bonsai b/src/Extensions/SolveBlockTransitions.bonsai index 344ab306..c069e019 100644 --- a/src/Extensions/SolveBlockTransitions.bonsai +++ b/src/Extensions/SolveBlockTransitions.bonsai @@ -711,6 +711,9 @@ + + + @@ -723,6 +726,7 @@ + diff --git a/src/aind_behavior_vr_foraging/task_logic.py b/src/aind_behavior_vr_foraging/task_logic.py index d51fd0f2..b09276c6 100644 --- a/src/aind_behavior_vr_foraging/task_logic.py +++ b/src/aind_behavior_vr_foraging/task_logic.py @@ -385,6 +385,7 @@ class RewardFunctionRule(str, Enum): ON_TIME = "OnTime" ON_DISTANCE = "OnDistance" ON_THIS_PATCH_ENTRY = "OnThisPatchEntry" + ON_PATCH_ENTRY = "OnPatchEntry" ON_CHOICE_ACCUMULATED = "OnChoiceAccumulated" ON_REWARD_ACCUMULATED = "OnRewardAccumulated" ON_TIME_ACCUMULATED = "OnTimeAccumulated" @@ -449,24 +450,45 @@ class OutsideRewardFunction(_RewardFunction): ) -class OnThisPatchEntryFunction(_RewardFunction): +class OnThisPatchEntryRewardFunction(_RewardFunction): """ A RewardFunction that is applied when the animal enters the patch. """ - function_type: Literal["OnThisPatchEntryFunction"] = "OnThisPatchEntryFunction" + function_type: Literal["OnThisPatchEntryRewardFunction"] = "OnThisPatchEntryRewardFunction" rule: Literal[RewardFunctionRule.ON_THIS_PATCH_ENTRY] = Field( default=RewardFunctionRule.ON_THIS_PATCH_ENTRY, description="Rule to trigger reward function" ) +class PersistentRewardFunction(_RewardFunction): + """ + A RewardFunction that is always active. + """ + + function_type: Literal["PersistentRewardFunction"] = "PersistentRewardFunction" + rule: Literal[ + RewardFunctionRule.ON_REWARD, + RewardFunctionRule.ON_CHOICE, + RewardFunctionRule.ON_TIME, + RewardFunctionRule.ON_DISTANCE, + RewardFunctionRule.ON_CHOICE_ACCUMULATED, + RewardFunctionRule.ON_REWARD_ACCUMULATED, + RewardFunctionRule.ON_TIME_ACCUMULATED, + RewardFunctionRule.ON_DISTANCE_ACCUMULATED, + RewardFunctionRule.ON_PATCH_ENTRY, + ] + + if TYPE_CHECKING: - RewardFunction = Union[PatchRewardFunction, OutsideRewardFunction, OnThisPatchEntryFunction] + RewardFunction = Union[ + PatchRewardFunction, OutsideRewardFunction, OnThisPatchEntryRewardFunction, PersistentRewardFunction + ] else: RewardFunction = TypeAliasType( "RewardFunction", Annotated[ - Union[PatchRewardFunction, OutsideRewardFunction, OnThisPatchEntryFunction], + Union[PatchRewardFunction, OutsideRewardFunction, OnThisPatchEntryRewardFunction, PersistentRewardFunction], Field(discriminator="function_type"), ], ) @@ -931,7 +953,11 @@ class BlockEndConditionPatchCount(_BlockEndConditionBase): if TYPE_CHECKING: BlockEndCondition = Union[ - BlockEndConditionDuration, BlockEndConditionDistance, BlockEndConditionChoice, BlockEndConditionReward + BlockEndConditionDuration, + BlockEndConditionDistance, + BlockEndConditionChoice, + BlockEndConditionReward, + BlockEndConditionPatchCount, ] else: BlockEndCondition = TypeAliasType(