Commit d60db3a

Merge pull request #413 from AllenNeuralDynamics/feat-continuous-replenishment
Add persistent reward function
2 parents 231f70f + 491fba3

File tree: 8 files changed, +1499 -31 lines changed

examples/task_mcm.py

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ def compute_cmc_transition_probability(n_states, rep_rate, T=3.5, dt=0.1) -> np.
 operation_control = vr_task_logic.OperationControl(
     movable_spout_control=vr_task_logic.MovableSpoutControl(enabled=False),
-    audio_control=vr_task_logic.AudioControl(duration=0.2, frequency=5000),
+    audio_control=vr_task_logic.AudioControl(duration=0.2, frequency=9999),
     odor_control=vr_task_logic.OdorControl(),
     position_control=vr_task_logic.PositionControl(
         frequency_filter_cutoff=5,

examples/task_patch_foraging.py

Lines changed: 1 addition & 1 deletion

@@ -108,7 +108,7 @@ def PostPatchVirtualSiteGeneratorHelper(contrast: float = 1, friction: float = 0
         rule=vr_task_logic.RewardFunctionRule.ON_REWARD,
     )

-    reset_function = vr_task_logic.OnThisPatchEntryFunction(
+    reset_function = vr_task_logic.OnThisPatchEntryRewardFunction(
         available=vr_task_logic.SetValueFunction(value=vr_task_logic.scalar_value(0.1))
     )

examples/test_single_site_patch.py

Lines changed: 186 additions & 0 deletions

@@ -0,0 +1,186 @@
import os
from typing import Optional

import aind_behavior_services.task_logic.distributions as distributions
from aind_behavior_curriculum import Stage, TrainerState

import aind_behavior_vr_foraging.task_logic as vr_task_logic
from aind_behavior_vr_foraging.task_logic import (
    AindVrForagingTaskLogic,
    AindVrForagingTaskParameters,
)

MINIMUM_INTERPATCH_LENGTH = 50
MEAN_INTERPATCH_LENGTH = 150
MAXIMUM_INTERPATCH_LENGTH = 500
INTERSITE_LENGTH = 50
REWARDSITE_LENGTH = 50
REWARD_AMOUNT = 3
VELOCITY_THRESHOLD = 15  # cm/s

P_REWARD_BLOCK: list[tuple[float, Optional[float], Optional[float]]] = [
    (1.0, 1.0, None),
    (0.8, 0.8, None),
    (0.8, 0.2, None),
]

P_BAIT_BLOCK = [
    (1.0, 1.0, None),
    (0.4, 0.4, None),
    (0.4, 0.1, None),
]


def make_patch(
    label: str,
    state_index: int,
    odor_index: int,
    p_reward: float,
    p_replenish: float,
):
    baiting_function = vr_task_logic.PersistentRewardFunction(
        rule=vr_task_logic.RewardFunctionRule.ON_PATCH_ENTRY,
        probability=vr_task_logic.SetValueFunction(
            value=distributions.BinomialDistribution(
                distribution_parameters=distributions.BinomialDistributionParameters(n=1, p=p_replenish),
                scaling_parameters=distributions.ScalingParameters(offset=p_reward),
                truncation_parameters=distributions.TruncationParameters(min=p_reward, max=1),
            ),
        ),
    )

    depletion_function = vr_task_logic.PatchRewardFunction(
        probability=vr_task_logic.SetValueFunction(
            value=vr_task_logic.scalar_value(p_reward),
        ),
        rule=vr_task_logic.RewardFunctionRule.ON_REWARD,
    )

    return vr_task_logic.Patch(
        label=label,
        state_index=state_index,
        odor_specification=vr_task_logic.OdorSpecification(index=odor_index, concentration=1),
        patch_terminators=[
            vr_task_logic.PatchTerminatorOnChoice(count=vr_task_logic.scalar_value(1)),
            vr_task_logic.PatchTerminatorOnRejection(count=vr_task_logic.scalar_value(1)),
        ],
        reward_specification=vr_task_logic.RewardSpecification(
            amount=vr_task_logic.scalar_value(REWARD_AMOUNT),
            probability=vr_task_logic.scalar_value(p_reward),
            available=vr_task_logic.scalar_value(999999),
            delay=vr_task_logic.scalar_value(0.5),
            operant_logic=vr_task_logic.OperantLogic(
                is_operant=False,
                stop_duration=0.5,
                time_to_collect_reward=100000,
                grace_distance_threshold=10,
            ),
            reward_function=[baiting_function, depletion_function],
        ),
        patch_virtual_sites_generator=vr_task_logic.PatchVirtualSitesGenerator(
            inter_patch=vr_task_logic.VirtualSiteGenerator(
                render_specification=vr_task_logic.RenderSpecification(contrast=1),
                label=vr_task_logic.VirtualSiteLabels.INTERPATCH,
                length_distribution=distributions.ExponentialDistribution(
                    distribution_parameters=distributions.ExponentialDistributionParameters(
                        rate=1 / MEAN_INTERPATCH_LENGTH
                    ),
                    scaling_parameters=distributions.ScalingParameters(offset=MINIMUM_INTERPATCH_LENGTH),
                    truncation_parameters=distributions.TruncationParameters(
                        min=MINIMUM_INTERPATCH_LENGTH,
                        max=MAXIMUM_INTERPATCH_LENGTH,
                    ),
                ),
            ),
            inter_site=vr_task_logic.VirtualSiteGenerator(
                render_specification=vr_task_logic.RenderSpecification(contrast=0.5),
                label=vr_task_logic.VirtualSiteLabels.INTERSITE,
                length_distribution=vr_task_logic.scalar_value(INTERSITE_LENGTH),
            ),
            reward_site=vr_task_logic.VirtualSiteGenerator(
                render_specification=vr_task_logic.RenderSpecification(contrast=0.5),
                label=vr_task_logic.VirtualSiteLabels.REWARDSITE,
                length_distribution=vr_task_logic.scalar_value(REWARDSITE_LENGTH),
            ),
        ),
    )


def make_block(
    p_rew: tuple[float, Optional[float], Optional[float]],
    p_replenish: tuple[float, Optional[float], Optional[float]],
    n_min_trials: int = 100,
) -> vr_task_logic.Block:
    patches = [make_patch(label="OdorA", state_index=0, odor_index=0, p_reward=p_rew[0], p_replenish=p_replenish[0])]
    if p_rew[1] is not None:
        assert p_replenish[1] is not None
        patches.append(
            make_patch(label="OdorB", state_index=1, odor_index=1, p_reward=p_rew[1], p_replenish=p_replenish[1])
        )
    if p_rew[2] is not None:
        assert p_replenish[2] is not None
        patches.append(
            make_patch(label="OdorC", state_index=2, odor_index=2, p_reward=p_rew[2], p_replenish=p_replenish[2])
        )

    per_p = 1.0 / len(patches)
    return vr_task_logic.Block(
        environment_statistics=vr_task_logic.EnvironmentStatistics(
            first_state_occupancy=[per_p] * len(patches),
            transition_matrix=[[per_p] * len(patches) for _ in range(len(patches))],
            patches=patches,
        ),
        end_conditions=[
            vr_task_logic.BlockEndConditionPatchCount(
                value=distributions.ExponentialDistribution(
                    distribution_parameters=distributions.ExponentialDistributionParameters(rate=1 / 25),
                    scaling_parameters=distributions.ScalingParameters(offset=n_min_trials),
                    truncation_parameters=distributions.TruncationParameters(min=n_min_trials, max=n_min_trials + 50),
                )
            )
        ],
    )


operation_control = vr_task_logic.OperationControl(
    movable_spout_control=vr_task_logic.MovableSpoutControl(enabled=False),
    audio_control=vr_task_logic.AudioControl(duration=0.2, frequency=9999),
    odor_control=vr_task_logic.OdorControl(),
    position_control=vr_task_logic.PositionControl(
        frequency_filter_cutoff=5,
        velocity_threshold=VELOCITY_THRESHOLD,
    ),
)


task_logic = AindVrForagingTaskLogic(
    task_parameters=AindVrForagingTaskParameters(
        rng_seed=None,
        environment=vr_task_logic.BlockStructure(
            blocks=[
                make_block(p_rew=P_REWARD_BLOCK[i], p_replenish=P_BAIT_BLOCK[i], n_min_trials=100)
                for i in range(len(P_REWARD_BLOCK))
            ],
            sampling_mode="Sequential",
        ),
        operation_control=operation_control,
    ),
    stage_name="single_site_patch",
)


def main(path_seed: str = "./local/SingleSitePatch_{schema}.json"):
    example_task_logic = task_logic
    example_trainer_state = TrainerState(
        stage=Stage(name="example_stage", task=example_task_logic), curriculum=None, is_on_curriculum=False
    )
    os.makedirs(os.path.dirname(path_seed), exist_ok=True)
    models = [example_task_logic, example_trainer_state]

    for model in models:
        with open(path_seed.format(schema=model.__class__.__name__), "w", encoding="utf-8") as f:
            f.write(model.model_dump_json(indent=2))


if __name__ == "__main__":
    main()

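The baiting function above is the core of the new persistent-replenishment behavior. Below is a minimal sketch, not part of the commit, of how its probability update would resolve on each patch entry, assuming BinomialDistribution(n=1, p) draws 0 or 1, ScalingParameters(offset=o) adds o to the draw, and TruncationParameters(min, max) clamps the result; the helper name sampled_entry_probability is hypothetical.

import random


def sampled_entry_probability(p_reward: float, p_replenish: float) -> float:
    # Hypothetical helper mirroring the distribution used by the
    # PersistentRewardFunction(rule=ON_PATCH_ENTRY) in make_patch() above.
    draw = 1 if random.random() < p_replenish else 0  # Binomial(n=1, p=p_replenish) sample
    scaled = draw + p_reward                          # ScalingParameters(offset=p_reward)
    return min(max(scaled, p_reward), 1.0)            # TruncationParameters(min=p_reward, max=1)

Under that reading, with p_reward=0.2 and p_replenish=0.1 (OdorB in the third block), the reward probability is reset to 1.0 on roughly 10% of patch entries and otherwise stays at the 0.2 baseline, while the ON_REWARD depletion function returns it to p_reward after each delivered reward.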
src/DataSchemas/aind_behavior_vr_foraging.json

Lines changed: 78 additions & 7 deletions

@@ -2927,12 +2927,12 @@
       "title": "OlfactometerChannelType",
       "type": "string"
     },
-    "OnThisPatchEntryFunction": {
+    "OnThisPatchEntryRewardFunction": {
       "description": "A RewardFunction that is applied when the animal enters the patch.",
       "properties": {
         "function_type": {
-          "const": "OnThisPatchEntryFunction",
-          "default": "OnThisPatchEntryFunction",
+          "const": "OnThisPatchEntryRewardFunction",
+          "default": "OnThisPatchEntryRewardFunction",
           "title": "Function Type",
           "type": "string"
         },

@@ -2980,7 +2980,7 @@
           "type": "string"
         }
       },
-      "title": "OnThisPatchEntryFunction",
+      "title": "OnThisPatchEntryRewardFunction",
       "type": "object"
     },
     "OperantLogic": {

@@ -3694,6 +3694,73 @@
       "title": "PdfDistributionParameters",
       "type": "object"
     },
+    "PersistentRewardFunction": {
+      "description": "A RewardFunction that is always active.",
+      "properties": {
+        "function_type": {
+          "const": "PersistentRewardFunction",
+          "default": "PersistentRewardFunction",
+          "title": "Function Type",
+          "type": "string"
+        },
+        "amount": {
+          "default": null,
+          "description": "Defines the amount of reward replenished per rule unit.",
+          "oneOf": [
+            {
+              "$ref": "#/$defs/PatchUpdateFunction"
+            },
+            {
+              "type": "null"
+            }
+          ]
+        },
+        "probability": {
+          "default": null,
+          "description": "Defines the probability of reward replenished per rule unit.",
+          "oneOf": [
+            {
+              "$ref": "#/$defs/PatchUpdateFunction"
+            },
+            {
+              "type": "null"
+            }
+          ]
+        },
+        "available": {
+          "default": null,
+          "description": "Defines the amount of reward available replenished in the patch per rule unit.",
+          "oneOf": [
+            {
+              "$ref": "#/$defs/PatchUpdateFunction"
+            },
+            {
+              "type": "null"
+            }
+          ]
+        },
+        "rule": {
+          "enum": [
+            "OnReward",
+            "OnChoice",
+            "OnTime",
+            "OnDistance",
+            "OnChoiceAccumulated",
+            "OnRewardAccumulated",
+            "OnTimeAccumulated",
+            "OnDistanceAccumulated",
+            "OnPatchEntry"
+          ],
+          "title": "Rule",
+          "type": "string"
+        }
+      },
+      "required": [
+        "rule"
+      ],
+      "title": "PersistentRewardFunction",
+      "type": "object"
+    },
     "PoissonDistribution": {
       "properties": {
         "family": {

@@ -3856,9 +3923,10 @@
     "RewardFunction": {
       "discriminator": {
         "mapping": {
-          "OnThisPatchEntryFunction": "#/$defs/OnThisPatchEntryFunction",
+          "OnThisPatchEntryRewardFunction": "#/$defs/OnThisPatchEntryRewardFunction",
           "OutsideRewardFunction": "#/$defs/OutsideRewardFunction",
-          "PatchRewardFunction": "#/$defs/PatchRewardFunction"
+          "PatchRewardFunction": "#/$defs/PatchRewardFunction",
+          "PersistentRewardFunction": "#/$defs/PersistentRewardFunction"
         },
         "propertyName": "function_type"
       },

@@ -3870,7 +3938,10 @@
           "$ref": "#/$defs/OutsideRewardFunction"
         },
         {
-          "$ref": "#/$defs/OnThisPatchEntryFunction"
+          "$ref": "#/$defs/OnThisPatchEntryRewardFunction"
+        },
+        {
+          "$ref": "#/$defs/PersistentRewardFunction"
         }
       ]
     },

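For reference, a minimal sketch, not part of the commit, of how the new schema entry maps back onto the Python model, assuming vr_task_logic.PersistentRewardFunction mirrors the definition above (only rule is required; amount, probability, and available default to null):

import aind_behavior_vr_foraging.task_logic as vr_task_logic

# Minimal instance: only "rule" is required; the discriminator field
# "function_type" defaults to "PersistentRewardFunction".
fn = vr_task_logic.PersistentRewardFunction(rule=vr_task_logic.RewardFunctionRule.ON_PATCH_ENTRY)
print(fn.model_dump_json(indent=2))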
0 commit comments
