Skip to content

Commit 6223174

Browse files
committed
added experiment prototype
1 parent 6dfddd3 commit 6223174

File tree

2 files changed

+211
-27
lines changed

2 files changed

+211
-27
lines changed

experiments/base/experiments.py

Lines changed: 127 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from experiments.base.stimulus_process import Timer
1313
from experiments.utils.exp_setup import get_experiment_settings, setup_trigger, setup_process
1414
from utils.plotter import plot_triggers_response
15+
import random
1516

1617

1718
class BaseExperiment:
@@ -138,7 +139,6 @@ def check_skeleton(self, frame, skeleton):
138139
self._trial_timers[trial].start()
139140

140141
self._process.put(self._current_trial)
141-
return result, response
142142

143143
@property
144144
def _trials(self):
@@ -184,6 +184,120 @@ def get_trial(self):
184184
return self._current_trial
185185

186186

187+
class BaseTrialExperiment(BaseExperiment):
188+
def __init__(self):
189+
super().__init__()
190+
self._name = 'BaseTrialExperiment'
191+
self.experiment_finished = False
192+
self._event = None
193+
self._print_check = False
194+
self._current_trial = None
195+
self._result_list = []
196+
self._success_count = 0
197+
198+
self._parameter_dict = dict(TRIGGER = 'str',
199+
PROCESS = 'str',
200+
INTERTRIAL_TIME = 'int',
201+
TRIAL_TRIGGER = 'str',
202+
TRIAL_TIME = 'int',
203+
STIMULUS_TIME = 'int',
204+
RESULT_FUNC = 'str',
205+
EXP_LENGTH = 'int',
206+
EXP_COMPLETION = 'int',
207+
EXP_TIME = 'int')
208+
209+
self._settings_dict = get_experiment_settings(self._name, self._parameter_dict)
210+
self._process = setup_process(self._settings_dict['PROCESS'])
211+
self._init_trigger = setup_trigger(self._settings_dict['TRIGGER'])
212+
self._trials_list = self.generate_trials_list(self._trials, self._settings_dict['EXP_LENGTH'])
213+
self._trial_timer = Timer(self._settings_dict['TRIAL_TIME'])
214+
self._exp_timer = Timer(self._settings_dict['EXP_TIME'])
215+
self._intertrial_timer = Timer(self._settings_dict['INTERTRIAL_TIME'])
216+
217+
218+
219+
def check_skeleton(self, frame, skeleton):
220+
status, trial = self._process.get_status()
221+
if status:
222+
current_trial = self._trials[trial]
223+
condition, response = current_trial['trigger'].check_skeleton(skeleton)
224+
self._process.pass_condition(condition)
225+
result = self._process.get_result()
226+
if result is not None:
227+
self.process_result(result, trial)
228+
self._current_trial = None
229+
# check if all trials were successful until completion
230+
if self._success_count >= self._settings_dict['EXP_COMPLETION']:
231+
print("Experiment is finished")
232+
print("Trial reached required amount of successes")
233+
self.stop_experiment()
234+
235+
# if not continue
236+
print(' Going into Intertrial time.')
237+
self._intertrial_timer.reset()
238+
self._intertrial_timer.start()
239+
result = None
240+
plot_triggers_response(frame, response)
241+
242+
elif not self._intertrial_timer.check_timer():
243+
if self._current_trial is None:
244+
self._current_trial = next(self._trials_list,False)
245+
elif not self._current_trial:
246+
print("Experiment is finished due to max. trial repetition.")
247+
print(self._result_list)
248+
self.stop_experiment()
249+
else:
250+
init_result, response_body = self._init_trigger.check_skeleton(skeleton)
251+
if init_result:
252+
# check trial start triggers
253+
self._process.set_trial(self._current_trial)
254+
self._print_check = False
255+
elif not self._print_check:
256+
print('Next trial: #' + str(len(self._result_list) + 1) + ' ' + self._current_trial)
257+
print('Animal is not meeting trial start criteria, the start of trial is delayed.')
258+
self._print_check = True
259+
# if experimental time ran out, finish experiments
260+
super().check_exp_timer()
261+
262+
def process_result(self, result, trial):
263+
"""
264+
Will add result if TRUE or reset comp_counter if FALSE
265+
:param result: bool if trial was successful
266+
:param trial: str name of the trial
267+
:return:
268+
"""
269+
self._result_list.append((trial, result))
270+
if result is True:
271+
self._success_count +=1
272+
print('Trial successful!')
273+
else:
274+
print('Trial failed.')
275+
#
276+
277+
@staticmethod
278+
def generate_trials_list(trials: dict, length: int):
279+
trials_list = []
280+
for trial in range(length):
281+
trials_list.append(random.choice(list(trials.keys())))
282+
return iter(trials_list)
283+
284+
@property
285+
def _trials(self):
286+
287+
trigger = setup_trigger(self._settings_dict['TRIAL_TRIGGER'])
288+
if self._settings_dict['RESULT_FUNC'] == 'all':
289+
result_func = all
290+
elif self._settings_dict['RESULT_FUNC'] == 'any':
291+
result_func = any
292+
else:
293+
raise ValueError(f'Result function can only be "all" or "any", not {self._settings_dict["RESULT_FUNC"]}.')
294+
trials = {'Trial': dict(stimulus_timer=Timer(self._settings_dict['STIMULUS_TIME']),
295+
success_timer=Timer(self._settings_dict['TRIAL_TIME']),
296+
trigger=trigger,
297+
result_func=result_func)}
298+
299+
return trials
300+
187301

188302
class BaseConditionalExperiment(BaseExperiment):
189303
"""
@@ -208,6 +322,7 @@ def __init__(self):
208322
self._current_trial = None
209323

210324
self._exp_timer = Timer(self._settings_dict['EXP_TIME'])
325+
self._intertrial_timer = Timer(self._settings_dict['INTERTRIAL_TIME'])
211326

212327
self._trigger = setup_trigger(self._settings_dict['TRIGGER'])
213328

@@ -225,16 +340,17 @@ def check_skeleton(self, frame, skeleton):
225340
self.stop_experiment()
226341

227342
elif not self.experiment_finished:
343+
if not self._intertrial_timer.check_timer():
344+
# check if condition is met
345+
result, response = self._trigger.check_skeleton(skeleton=skeleton)
346+
if result:
347+
self._event_count += 1
348+
print('Stimulation #{self._event_count}'.format())
349+
self._intertrial_timer.reset()
350+
self._intertrial_timer.start()
228351

229-
# check if condition is met
230-
result, response = self._trigger.check_skeleton(skeleton=skeleton)
231-
plot_triggers_response(frame, response)
232-
if result:
233-
self._event_count += 1
234-
print('Stimulation #{self._event_count}'.format())
235-
236-
self._process.put(result)
237-
return result, response
352+
plot_triggers_response(frame, response)
353+
self._process.put(result)
238354

239355
def check_exp_timer(self):
240356
"""
@@ -269,12 +385,9 @@ def get_trial(self):
269385
return self._current_trial
270386

271387

272-
273-
274-
275-
276388
"""Standardexperiments that can be setup by using the experiment config"""
277389

390+
278391
class BaseOptogeneticExperiment(BaseExperiment):
279392
"""Standard implementation of an optogenetic experiment"""
280393

experiments/base/stimulus_process.py

Lines changed: 84 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def base_conditional_supply_protocol_run(condition_q: mp.Queue, stimulus_name):
8888
stimulation.remove()
8989

9090

91-
def base_trial_protocol_run(trial_q: mp.Queue, success_q: mp.Queue, trials: dict):
91+
def base_trial_protocol_run(trial_q: mp.Queue, condition_q: mp.Queue, success_q: mp.Queue, stimulation_name):
9292
"""
9393
The function to use in ProtocolProcess class
9494
Designed to be run continuously alongside the main loop
@@ -99,22 +99,91 @@ def base_trial_protocol_run(trial_q: mp.Queue, success_q: mp.Queue, trials: dict
9999
:param stimulus_name: exact name of stimulus function in base_stimulus.py
100100
"""
101101
current_trial = None
102-
# TODO: make this adaptive and working
103-
trial_dict = {}
104-
stimulus_name = 'BaseStimulation'
105-
stimulation = setup_stimulation(stimulus_name)
102+
stimulation = setup_stimulation(stimulation_name)
106103
# starting the main loop without any protocol running
107104
while True:
108105
if trial_q.empty() and current_trial is None:
109106
pass
110107
elif trial_q.full():
111-
current_trial = trial_q.get()
112-
print(current_trial)
108+
finished_trial = False
109+
# starting timers
110+
stimulus_timer = trials[current_trial]['stimulus_timer']
111+
success_timer = trials[current_trial]['success_timer']
112+
print('Starting protocol {}'.format(current_trial))
113+
stimulus_timer.start()
114+
success_timer.start()
115+
condition_list = []
113116
# this branch is for already running protocol
114117
elif current_trial is not None:
115-
success_q.put(True)
116-
stimulation.stimulate()
117-
current_trial = None
118+
# checking for stimulus timer and outputting correct image
119+
if stimulus_timer.check_timer():
120+
# if stimulus timer is running, show stimulus
121+
stimulation.stimulate()
122+
else:
123+
# if the timer runs out, finish protocol and reset timer
124+
trials[current_trial]['stimulus_timer'].reset()
125+
current_trial = None
126+
127+
# checking if any condition was passed
128+
if condition_q.full():
129+
stimulus_condition = condition_q.get()
130+
# checking if timer for condition is running and condition=True
131+
if success_timer.check_timer():
132+
# print('That was a success!')
133+
condition_list.append(stimulus_condition)
134+
# elif success_timer.check_timer() and not stimulus_condition:
135+
# # print('That was not a success')
136+
# condition_list.append(False)
137+
138+
# checking if the timer for condition has run out
139+
if not success_timer.check_timer() and not finished_trial:
140+
if CTRL:
141+
# start a random time interval
142+
# TODO: working ctrl timer that does not set new time each frame...
143+
ctrl_time = random.randint(0,INTERTRIAL_TIME + 1)
144+
ctrl_timer = Timer(ctrl_time)
145+
ctrl_timer.start()
146+
print('Waiting for extra' + str(ctrl_time) + ' sec')
147+
if not ctrl_timer.check_timer():
148+
# in ctrl just randomly decide between the two
149+
print('Random choice between both stimuli')
150+
if random.random() >= 0.5:
151+
# very fast random choice between TRUE and FALSE
152+
deliver_liqreward()
153+
print('Delivered Reward')
154+
155+
else:
156+
deliver_tone_shock()
157+
print('Delivered Aversive')
158+
159+
ctrl_timer.reset()
160+
finished_trial = True
161+
# outputting the result, whatever it is
162+
success = trials[current_trial]['result_func'](condition_list)
163+
success_q.put(success)
164+
trials[current_trial]['success_timer'].reset()
165+
166+
else:
167+
if current_trial == 'Bluebar_whiteback':
168+
deliver_tone_shock()
169+
print('Delivered Aversive')
170+
elif current_trial == 'Greenbar_whiteback':
171+
if trials[current_trial]['random_reward']:
172+
if random.random() >= 0.5:
173+
# very fast random choice between TRUE and FALSE
174+
deliver_liqreward()
175+
print('Delivered Reward')
176+
else:
177+
print('No Reward')
178+
else:
179+
deliver_liqreward()
180+
# resetting the timer
181+
print('Timer for condition run out')
182+
finished_trial = True
183+
# outputting the result, whatever it is
184+
success = trials[current_trial]['result_func'](condition_list)
185+
success_q.put(success)
186+
trials[current_trial]['success_timer'].reset()
118187

119188

120189
class BaseProtocolProcess:
@@ -133,8 +202,10 @@ def __init__(self, trials: dict = None):
133202
if self._settings_dict['TYPE'] == 'trial' and trials is not None:
134203
self._trial_queue = mp.Queue(1)
135204
self._success_queue = mp.Queue(1)
205+
self._condition_queue = mp.Queue(1)
136206
self._protocol_process = mp.Process(target=base_trial_protocol_run,
137-
args=(self._trial_queue, self._success_queue, trials))
207+
args=(self._trial_queue, self._trial_queue,
208+
self._success_queue, self._settings_dict['STIMULATION']))
138209
elif self._settings_dict['TYPE'] == 'switch':
139210
self._condition_queue = mp.Queue(1)
140211
self._protocol_process = mp.Process(target=base_conditional_switch_protocol_run,
@@ -158,7 +229,7 @@ def end(self):
158229
"""
159230
Ending the process
160231
"""
161-
if self._settings_dict['TYPE'] == 'condition':
232+
if self._settings_dict['TYPE'] == 'switch' or self._settings_dict['TYPE'] == 'supply':
162233
self._condition_queue.close()
163234
elif self._settings_dict['TYPE'] == 'trial':
164235
self._trial_queue.close()
@@ -182,7 +253,7 @@ def put(self, input_p):
182253
self._running = True
183254
self._current_trial = input_p
184255

185-
elif self._settings_dict['TYPE'] == 'condition':
256+
elif self._settings_dict['TYPE'] == 'switch' or self._settings_dict['TYPE'] == 'supply':
186257
if self._condition_queue.empty():
187258
self._condition_queue.put(input_p)
188259

0 commit comments

Comments
 (0)