Skip to content

Commit f412ce1

Browse files
committed
clean-up, addition of TrialDataModel, argument parsing, ruff
1 parent b68545e commit f412ce1

File tree

2 files changed

+93
-92
lines changed

2 files changed

+93
-92
lines changed
Lines changed: 91 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,215 +1,215 @@
1-
import numpy as np
2-
import pandas as pd
3-
from pybpodapi.protocol import StateMachine
1+
from pathlib import Path
2+
from typing import Annotated
3+
4+
import yaml
5+
from annotated_types import Interval
46

57
import iblrig.misc
6-
from iblrig.base_choice_world import BiasedChoiceWorldSession
8+
from iblrig.base_choice_world import BiasedChoiceWorldSession, BiasedChoiceWorldTrialData
79
from iblrig.hardware import SOFTCODE
810
from iblutil.util import setup_logger
11+
from pybpodapi.protocol import StateMachine
912

1013
log = setup_logger(__name__)
1114

12-
INTERACTIVE_DELAY = 1.0
13-
NTRIALS_INIT = 2000
15+
# read defaults from task_parameters.yaml
16+
with open(Path(__file__).parent.joinpath('task_parameters.yaml')) as f:
17+
DEFAULTS = yaml.safe_load(f)
1418

1519

16-
class Session(BiasedChoiceWorldSession):
20+
class CuedBiasedChoiceWorldTrialData(BiasedChoiceWorldTrialData):
21+
"""Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.BiasedChoiceWorldTrialData`."""
1722

23+
play_audio_cue: bool
24+
25+
26+
class Session(BiasedChoiceWorldSession):
1827
protocol_name = 'samuel_cuedBiasedChoiceWorld'
28+
TrialDataModel = CuedBiasedChoiceWorldTrialData
1929

20-
def __init__(self, *args, **kwargs):
30+
def __init__(self, *args, probability_audio_cue: float = 1.0, **kwargs):
2131
super().__init__(**kwargs)
2232

33+
# store parameters to task_params
34+
self.session_info['PROBABILITY_AUDIO_CUE'] = probability_audio_cue
35+
2336
# loads in the settings in order to determine the main sync and thus the pipeline extractor tasks
2437
is_main_sync = self.hardware_settings.get('MAIN_SYNC', False)
2538
trials_task = 'CuedBiasedTrials' if is_main_sync else 'CuedBiasedTrialsTimeline'
2639
self.extractor_tasks = ['TrialRegisterRaw', trials_task, 'TrainingStatus']
40+
2741
# Update experiment description which was created by superclass init
2842
self.experiment_description['tasks'][-1][self.protocol_name]['extractors'] = self.extractor_tasks
2943

30-
# init behaviour data
31-
self.movement_left = self.device_rotary_encoder.THRESHOLD_EVENTS[
32-
self.task_params.QUIESCENCE_THRESHOLDS[0]]
33-
self.movement_right = self.device_rotary_encoder.THRESHOLD_EVENTS[
34-
self.task_params.QUIESCENCE_THRESHOLDS[1]]
35-
# init counter variables
36-
self.trial_num = -1
37-
self.block_num = -1
38-
self.block_trial_num = -1
39-
# init the tables, there are 2 of them: a trials table and a ambient sensor data table
40-
self.trials_table = pd.DataFrame({
41-
'contrast': np.zeros(NTRIALS_INIT) * np.NaN,
42-
'position': np.zeros(NTRIALS_INIT) * np.NaN,
43-
'quiescent_period': np.zeros(NTRIALS_INIT) * np.NaN,
44-
'response_side': np.zeros(NTRIALS_INIT, dtype=np.int8),
45-
'response_time': np.zeros(NTRIALS_INIT) * np.NaN,
46-
'reward_amount': np.zeros(NTRIALS_INIT) * np.NaN,
47-
'reward_valve_time': np.zeros(NTRIALS_INIT) * np.NaN,
48-
'stim_angle': np.zeros(NTRIALS_INIT) * np.NaN,
49-
'stim_freq': np.zeros(NTRIALS_INIT) * np.NaN,
50-
'stim_gain': np.zeros(NTRIALS_INIT) * np.NaN,
51-
'stim_phase': np.zeros(NTRIALS_INIT) * np.NaN,
52-
'stim_reverse': np.zeros(NTRIALS_INIT, dtype=bool),
53-
'stim_sigma': np.zeros(NTRIALS_INIT) * np.NaN,
54-
'trial_correct': np.zeros(NTRIALS_INIT, dtype=bool),
55-
'trial_num': np.zeros(NTRIALS_INIT, dtype=np.int16),
56-
})
57-
5844
def get_state_machine_trial(self, i):
5945
sma = StateMachine(self.bpod)
6046
if i == 0: # First trial exception start camera
61-
session_delay_start = self.task_params.get("SESSION_DELAY_START", 0)
62-
log.info("First trial initializing, will move to next trial only if:")
63-
log.info("1. camera is detected")
64-
log.info(f"2. {session_delay_start} sec have elapsed")
47+
session_delay_start = self.task_params.get('SESSION_DELAY_START', 0)
48+
log.info('First trial initializing, will move to next trial only if:')
49+
log.info('1. camera is detected')
50+
log.info(f'2. {session_delay_start} sec have elapsed')
6551
sma.add_state(
66-
state_name="trial_start",
52+
state_name='trial_start',
6753
state_timer=0,
68-
state_change_conditions={"Port1In": "delay_initiation"},
69-
output_actions=[("SoftCode", SOFTCODE.TRIGGER_CAMERA), ("BNC1", 255)],
54+
state_change_conditions={'Port1In': 'delay_initiation'},
55+
output_actions=[('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
7056
) # start camera
7157
sma.add_state(
72-
state_name="delay_initiation",
58+
state_name='delay_initiation',
7359
state_timer=session_delay_start,
7460
output_actions=[],
75-
state_change_conditions={"Tup": "reset_rotary_encoder"},
61+
state_change_conditions={'Tup': 'reset_rotary_encoder'},
7662
)
7763
else:
7864
sma.add_state(
79-
state_name="trial_start",
65+
state_name='trial_start',
8066
state_timer=0, # ~100µs hardware irreducible delay
81-
state_change_conditions={"Tup": "reset_rotary_encoder"},
82-
output_actions=[self.bpod.actions.stop_sound, ("BNC1", 255)],
67+
state_change_conditions={'Tup': 'reset_rotary_encoder'},
68+
output_actions=[self.bpod.actions.stop_sound, ('BNC1', 255)],
8369
) # stop all sounds
8470

8571
sma.add_state(
86-
state_name="reset_rotary_encoder",
72+
state_name='reset_rotary_encoder',
8773
state_timer=0,
8874
output_actions=[self.bpod.actions.rotary_encoder_reset],
89-
state_change_conditions={"Tup": "quiescent_period"},
75+
state_change_conditions={'Tup': 'quiescent_period'},
9076
)
9177

9278
sma.add_state( # '>back' | '>reset_timer'
93-
state_name="quiescent_period",
79+
state_name='quiescent_period',
9480
state_timer=self.quiescent_period,
9581
output_actions=[],
9682
state_change_conditions={
97-
"Tup": "play_tone",
98-
self.movement_left: "reset_rotary_encoder",
99-
self.movement_right: "reset_rotary_encoder",
83+
'Tup': 'play_tone',
84+
self.movement_left: 'reset_rotary_encoder',
85+
self.movement_right: 'reset_rotary_encoder',
10086
},
10187
)
10288
# play tone, move on to next state if sound is detected, with a time-out of 0.1s
10389
# SP how can we make sure the delay between play_tone and stim_on is always exactly 1s?
10490
sma.add_state(
105-
state_name="play_tone",
91+
state_name='play_tone',
10692
state_timer=0.1, # SP is this necessary??
10793
output_actions=[self.bpod.actions.play_tone],
10894
state_change_conditions={
109-
"Tup": "interactive_delay",
110-
"BNC2High": "interactive_delay",
95+
'Tup': 'interactive_delay',
96+
'BNC2High': 'interactive_delay',
11197
},
11298
)
11399
# this will add a delay between auditory cue and visual stimulus
114100
# this needs to be precise and accurate based on the parameter
115101
sma.add_state(
116-
state_name="interactive_delay",
102+
state_name='interactive_delay',
117103
state_timer=self.task_params.INTERACTIVE_DELAY,
118104
output_actions=[],
119-
state_change_conditions={"Tup": "stim_on"},
105+
state_change_conditions={'Tup': 'stim_on'},
120106
)
121107
# show stimulus, move on to next state if a frame2ttl is detected, with a time-out of 0.1s
122108
sma.add_state(
123-
state_name="stim_on",
109+
state_name='stim_on',
124110
state_timer=0.1,
125111
output_actions=[self.bpod.actions.bonsai_show_stim],
126112
state_change_conditions={
127-
"Tup": "reset2_rotary_encoder",
128-
"BNC1High": "reset2_rotary_encoder",
129-
"BNC1Low": "reset2_rotary_encoder",
113+
'Tup': 'reset2_rotary_encoder',
114+
'BNC1High': 'reset2_rotary_encoder',
115+
'BNC1Low': 'reset2_rotary_encoder',
130116
},
131117
)
118+
132119
sma.add_state(
133-
state_name="reset2_rotary_encoder",
120+
state_name='reset2_rotary_encoder',
134121
state_timer=0.05, # the delay here is to avoid race conditions in the bonsai flow
135122
output_actions=[self.bpod.actions.rotary_encoder_reset],
136-
state_change_conditions={"Tup": "closed_loop"},
123+
state_change_conditions={'Tup': 'closed_loop'},
137124
)
138125

139126
sma.add_state(
140-
state_name="closed_loop",
127+
state_name='closed_loop',
141128
state_timer=self.task_params.RESPONSE_WINDOW,
142129
output_actions=[self.bpod.actions.bonsai_closed_loop],
143130
state_change_conditions={
144-
"Tup": "no_go",
145-
self.event_error: "freeze_error",
146-
self.event_reward: "freeze_reward",
131+
'Tup': 'no_go',
132+
self.event_error: 'freeze_error',
133+
self.event_reward: 'freeze_reward',
147134
},
148135
)
149136

150137
sma.add_state(
151-
state_name="no_go",
138+
state_name='no_go',
152139
state_timer=self.task_params.FEEDBACK_NOGO_DELAY_SECS,
153140
output_actions=[self.bpod.actions.bonsai_hide_stim, self.bpod.actions.play_noise],
154-
state_change_conditions={"Tup": "exit_state"},
141+
state_change_conditions={'Tup': 'exit_state'},
155142
)
156143

157144
sma.add_state(
158-
state_name="freeze_error",
145+
state_name='freeze_error',
159146
state_timer=0,
160147
output_actions=[self.bpod.actions.bonsai_freeze_stim],
161-
state_change_conditions={"Tup": "error"},
148+
state_change_conditions={'Tup': 'error'},
162149
)
163150

164151
sma.add_state(
165-
state_name="error",
152+
state_name='error',
166153
state_timer=self.task_params.FEEDBACK_ERROR_DELAY_SECS,
167154
output_actions=[self.bpod.actions.play_noise],
168-
state_change_conditions={"Tup": "hide_stim"},
155+
state_change_conditions={'Tup': 'hide_stim'},
169156
)
170157

171158
sma.add_state(
172-
state_name="freeze_reward",
159+
state_name='freeze_reward',
173160
state_timer=0,
174161
output_actions=[self.bpod.actions.bonsai_freeze_stim],
175-
state_change_conditions={"Tup": "reward"},
162+
state_change_conditions={'Tup': 'reward'},
176163
)
177164

178165
sma.add_state(
179-
state_name="reward",
166+
state_name='reward',
180167
state_timer=self.reward_time,
181-
output_actions=[("Valve1", 255), ("BNC1", 255)],
182-
state_change_conditions={"Tup": "correct"},
168+
output_actions=[('Valve1', 255), ('BNC1', 255)],
169+
state_change_conditions={'Tup': 'correct'},
183170
)
184171

185172
sma.add_state(
186-
state_name="correct",
173+
state_name='correct',
187174
state_timer=self.task_params.FEEDBACK_CORRECT_DELAY_SECS,
188175
output_actions=[],
189-
state_change_conditions={"Tup": "hide_stim"},
176+
state_change_conditions={'Tup': 'hide_stim'},
190177
)
191178

192179
sma.add_state(
193-
state_name="hide_stim",
180+
state_name='hide_stim',
194181
state_timer=0.1,
195182
output_actions=[self.bpod.actions.bonsai_hide_stim],
196183
state_change_conditions={
197-
"Tup": "exit_state",
198-
"BNC1High": "exit_state",
199-
"BNC1Low": "exit_state",
184+
'Tup': 'exit_state',
185+
'BNC1High': 'exit_state',
186+
'BNC1Low': 'exit_state',
200187
},
201188
)
202189

203190
sma.add_state(
204-
state_name="exit_state",
191+
state_name='exit_state',
205192
state_timer=self.task_params.ITI_DELAY_SECS,
206-
output_actions=[("BNC1", 255)],
207-
state_change_conditions={"Tup": "exit"},
193+
output_actions=[('BNC1', 255)],
194+
state_change_conditions={'Tup': 'exit'},
208195
)
209196
return sma
210197

198+
@staticmethod
199+
def extra_parser():
200+
parser = super(Session, Session).extra_parser()
201+
parser.add_argument(
202+
'--probability_audio_cue',
203+
option_strings=['--probability_audio_cue'],
204+
dest='probability_audio_cue',
205+
default=DEFAULTS.get('PROBABILITY_AUDIO_CUE', 1.0),
206+
type=float,
207+
help='defines the probability of the audio cue to be played',
208+
)
209+
return parser
210+
211211

212212
if __name__ == '__main__': # pragma: no cover
213-
kwargs = iblrig.misc.get_task_arguments(parents=[Session.extra_parser()])
214-
sess = Session(**kwargs)
213+
task_kwargs = iblrig.misc.get_task_arguments(parents=[Session.extra_parser()])
214+
sess = Session(**task_kwargs)
215215
sess.run()
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
'INTERACTIVE_DELAY': 1
1+
'INTERACTIVE_DELAY': 1
2+
'PROBABILITY_AUDIO_CUE': 1

0 commit comments

Comments
 (0)