Skip to content

Commit 406568c

Browse files
committed
update bandit extractors
1 parent 8825b7d commit 406568c

File tree

2 files changed

+106
-93
lines changed

2 files changed

+106
-93
lines changed

projects/task_extractor_map.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,8 @@
44
"nate_optoBiasedChoiceWorld": "projects.nate_optoBiasedChoiceWorld.TrialsOpto",
55
"FPChoiceWorld": "BiasedTrials",
66
"FPROptoChoiceWorld": "projects.alejandro_FPLROptoChoiceWorld.TrialsFPLROpto",
7-
"FPLOptoChoiceWorld": "projects.alejandro_FPLROptoChoiceWorld.TrialsFPLROpto"
7+
"FPLOptoChoiceWorld": "projects.alejandro_FPLROptoChoiceWorld.TrialsFPLROpto",
8+
"_bandit_biasedChoiceWorld": "projects.training_bandit.TrialsBandit",
9+
"_bandit_100_0_biasedChoiceWorld": "projects.training_bandit.TrialsBandit",
10+
"_bandit_alllaser_cued_ephysChoiceWorld": "projects.training_bandit.TrialsLaserBandit"
811
}

projects/training_bandit.py

Lines changed: 102 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,114 +1,104 @@
11
import logging
2-
from collections import OrderedDict
32
import numpy as np
43
from one.alf.io import AlfBunch
54

6-
from ibllib.pipes import tasks
75
import ibllib.io.extractors.training_trials as tt
8-
import ibllib.pipes.training_preprocessing as training_tasks
96
from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
10-
import ibllib.io.extractors.base
11-
import ibllib.io.raw_data_loaders as rawio
12-
137

148
_logger = logging.getLogger('ibllib')
159

1610

17-
class TrainingBanditTrials(tasks.Task):
18-
priority = 90
19-
level = 0
20-
force = False
21-
signature = {
22-
'input_files': [('_iblrig_taskData.raw.*', 'raw_behavior_data', True),
23-
('_iblrig_taskSettings.raw.*', 'raw_behavior_data', True),
24-
('_iblrig_encoderEvents.raw*', 'raw_behavior_data', True),
25-
('_iblrig_encoderPositions.raw*', 'raw_behavior_data', True)],
26-
'output_files': [('*trials.goCueTrigger_times.npy', 'alf', True),
27-
('*trials.itiDuration.npy', 'alf', False),
28-
('*trials.probabilityRewardLeft', 'alf', True),
29-
('*trials.table.pqt', 'alf', True),
30-
('*wheel.position.npy', 'alf', True),
31-
('*wheel.timestamps.npy', 'alf', True),
32-
('*wheelMoves.intervals.npy', 'alf', True),
33-
('*wheelMoves.peakAmplitude.npy', 'alf', True)]
34-
}
11+
class TrialsLaserBandit(BaseBpodTrialsExtractor):
3512

36-
def _run(self):
37-
"""
38-
Extracts an iblrig training session
39-
"""
40-
trials, wheel, output_files = extract_all(self.session_path, save=True)
41-
if trials is None:
42-
return None
13+
var_names = tt.TrainingTrials.var_names + ('probabilityRewardLeft', 'laserStimulation', 'laserProbability')
14+
save_names = tt.TrainingTrials.save_names + ('_av_trials.probabilityRewardLeft.npy',
15+
'_ibl_trials.laserStimulation.npy', '_av_trials.laserProbability.npy')
4316

44-
return output_files
17+
def _extract(self, extractor_classes=None, **kwargs) -> dict:
4518

19+
base = [BanditRepNum, tt.GoCueTriggerTimes, tt.StimOnTriggerTimes, tt.ItiInTimes, tt.StimOffTriggerTimes,
20+
tt.StimFreezeTriggerTimes, tt.ErrorCueTriggerTimes, LaserBanditTrialsTable, tt.PhasePosQuiescence,
21+
tt.PauseDuration, ProbabilityRewardLeft, BanditLaserStimulation, BanditLaserProbability]
4622

47-
def extract_all(session_path, save=True, bpod_trials=None, settings=None):
48-
"""Extract trials and wheel data.
23+
# Extract common biased choice world datasets
24+
out, _ = run_extractor_classes(
25+
base, session_path=self.session_path, bpod_trials=self.bpod_trials,
26+
settings=self.settings, save=False, task_collection=self.task_collection)
27+
28+
return {k: out[k] for k in self.var_names}
4929

50-
For task versions >= 5.0.0, outputs wheel data and trials.table dataset (+ some extra datasets)
5130

52-
Parameters
53-
----------
54-
session_path : str, pathlib.Path
55-
The path to the session
56-
save : bool
57-
If true save the data files to ALF
58-
bpod_trials : list of dicts
59-
The Bpod trial dicts loaded from the _iblrig_taskData.raw dataset
60-
settings : dict
61-
The Bpod settings loaded from the _iblrig_taskSettings.raw dataset
6231

63-
Returns
64-
-------
65-
A list of extracted data and a list of file paths if save is True (otherwise None)
66-
"""
6732

68-
extractor_type = ibllib.io.extractors.base.get_session_extractor_type(session_path)
69-
_logger.info(f"Extracting {session_path} as {extractor_type}")
70-
bpod_trials = bpod_trials or rawio.load_data(session_path)
71-
settings = settings or rawio.load_settings(session_path)
72-
_logger.info(f'{extractor_type} session on {settings["PYBPOD_BOARD"]}')
33+
class TrialsBandit(BaseBpodTrialsExtractor):
34+
var_names = tt.TrainingTrials.var_names + ('probabilityRewardLeft',)
35+
save_names = tt.TrainingTrials.save_names + ('_av_trials.probabilityRewardLeft.npy',)
7336

74-
if settings is None or settings['IBLRIG_VERSION'] == '':
75-
settings = {'IBLRIG_VERSION': '100.0.0'}
37+
def _extract(self, extractor_classes=None, **kwargs) -> dict:
7638

77-
# check that the extraction works for both the shaping 0-100 and the other one
78-
base = [BanditRepNum, tt.GoCueTriggerTimes, tt.StimOnTriggerTimes, tt.ItiInTimes, tt.StimOffTriggerTimes,
79-
tt.StimFreezeTriggerTimes, tt.ErrorCueTriggerTimes, ProbabilityRewardLeft, BanditTrialsTable]
39+
base = [BanditRepNum, tt.GoCueTriggerTimes, tt.StimOnTriggerTimes, tt.ItiInTimes, tt.StimOffTriggerTimes,
40+
tt.StimFreezeTriggerTimes, tt.ErrorCueTriggerTimes, BanditTrialsTable, tt.PhasePosQuiescence,
41+
tt.PauseDuration, ProbabilityRewardLeft]
8042

81-
trials, files_trials = run_extractor_classes(
82-
base, save=save, session_path=session_path, bpod_trials=bpod_trials, settings=settings)
43+
# Extract common biased choice world datasets
44+
out, _ = run_extractor_classes(
45+
base, session_path=self.session_path, bpod_trials=self.bpod_trials,
46+
settings=self.settings, save=False, task_collection=self.task_collection)
8347

84-
files_wheel = []
85-
wheel = OrderedDict({k: trials.pop(k) for k in tuple(trials.keys()) if 'wheel' in k})
8648

87-
_logger.info('session extracted \n') # timing info in log
49+
return {k: out[k] for k in self.var_names}
8850

89-
return trials, wheel, (files_trials + files_wheel) if save else None
9051

9152

92-
class BanditTrialsTable(tt.TrialsTable):
53+
class BanditTrialsTable(BaseBpodTrialsExtractor):
9354
"""
9455
Extracts the following into a table from Bpod raw data:
9556
intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
9657
feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
9758
Additionally extracts the following wheel data:
98-
wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude
59+
wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude
9960
"""
10061

101-
def _extract(self, **kwargs):
62+
var_names = tt.TrialsTable.var_names
63+
save_names = tt.TrialsTable.save_names
64+
65+
66+
def _extract(self, extractor_classes=None, **kwargs):
10267
base = [tt.Intervals, tt.GoCueTimes, tt.ResponseTimes, BanditChoice, tt.StimOnOffFreezeTimes, BanditContrastLR,
10368
tt.FeedbackTimes, tt.FeedbackType, tt.RewardVolume, BanditProbabilityLeft, tt.Wheel]
104-
exclude = [
105-
'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
106-
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement'
107-
]
108-
109-
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
110-
save=False)
111-
table = AlfBunch({k: v for k, v in out.items() if k not in exclude})
69+
70+
out, _ = run_extractor_classes(
71+
base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False,
72+
task_collection=self.task_collection)
73+
74+
table = AlfBunch({k: v for k, v in out.items() if k not in self.var_names})
75+
assert len(table.keys()) == 12
76+
77+
return table.to_df(), *(out.pop(x) for x in self.var_names if x != 'table')
78+
79+
80+
81+
class LaserBanditTrialsTable(BaseBpodTrialsExtractor):
82+
"""
83+
Extracts the following into a table from Bpod raw data:
84+
intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
85+
feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
86+
Additionally extracts the following wheel data:
87+
wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude
88+
"""
89+
90+
var_names = tt.TrialsTable.var_names
91+
save_names = tt.TrialsTable.save_names
92+
93+
def _extract(self, extractor_classes=None, **kwargs):
94+
base = [tt.Intervals, tt.GoCueTimes, tt.ResponseTimes, BanditChoice, tt.StimOnOffFreezeTimes, BanditContrastLR,
95+
tt.FeedbackTimes, tt.FeedbackType, BanditRewardVolume, BanditProbabilityLeft, tt.Wheel]
96+
97+
out, _ = run_extractor_classes(
98+
base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False,
99+
task_collection=self.task_collection)
100+
101+
table = AlfBunch({k: v for k, v in out.items() if k not in self.var_names})
112102
assert len(table.keys()) == 12
113103

114104
return table.to_df(), *(out.pop(x) for x in self.var_names if x != 'table')
@@ -188,22 +178,42 @@ def _extract(self):
188178
return contrastLeft, contrastRight
189179

190180

191-
class TrainingBanditPipeline(tasks.Pipeline):
192-
label = __name__
181+
class BanditRewardVolume(tt.RewardVolume):
182+
"""
183+
Load reward volume delivered for each trial. For trials where the reward was given by laser stimulation
184+
rather than water stimulation set reward volume to 0
185+
"""
186+
187+
def _extract(self):
188+
rewards = super(BanditRewardVolume, self)._extract()
189+
laser = np.array([t['opto_block'] for t in self.bpod_trials]).astype(bool)
190+
rewards[laser] = 0
191+
192+
return rewards
193193

194-
def __init__(self, session_path, **kwargs):
195-
super(TrainingBanditPipeline, self).__init__(session_path, **kwargs)
196-
tasks = OrderedDict()
197-
self.session_path = session_path
198-
# level 0
199-
tasks['TrainingRegisterRaw'] = training_tasks.TrainingRegisterRaw(self.session_path)
200-
tasks['TrainingBanditTrials'] = TrainingBanditTrials(self.session_path)
201-
tasks['TrainingVideoCompress'] = training_tasks.TrainingVideoCompress(self.session_path)
202-
tasks['TrainingAudio'] = training_tasks.TrainingAudio(self.session_path)
203-
# level 1
204-
tasks['TrainingDLC'] = training_tasks.TrainingDLC(
205-
self.session_path, parents=[tasks['TrainingVideoCompress']])
206-
self.tasks = tasks
207194

195+
class BanditLaserProbability(BaseBpodTrialsExtractor):
196+
save_names = '_av_trials.laserProbability.npy'
197+
var_names = 'laserProbability'
198+
199+
def _extract(self):
200+
laser = np.array([t['opto_block'] for t in self.bpod_trials]).astype(int)
201+
return laser
202+
203+
204+
class BanditLaserStimulation(BaseBpodTrialsExtractor):
205+
"""
206+
Get the trials where laser reward stimulation was given. Laser stimulation given when task was in laser block and feedback
207+
is correct
208+
"""
209+
210+
save_names = '_ibl_trials.laserStimulation.npy'
211+
var_names = 'laserStimulation'
212+
213+
def _extract(self):
214+
reward = np.array([~np.isnan(t['behavior_data']['States timestamps']['reward'][0][0]) for t in
215+
self.bpod_trials]).astype(bool)
216+
laser = np.array([t['opto_block'] for t in self.bpod_trials]).astype(int)
217+
laser[~reward] = 0
218+
return laser
208219

209-
__pipeline__ = TrainingBanditPipeline

0 commit comments

Comments
 (0)