Skip to content

Commit 279e96e

Browse files
authored
Merge pull request #478 from int-brain-lab/training_plots
Training plots
2 parents 1325f31 + 5b7bdc1 commit 279e96e

File tree

7 files changed

+638
-23
lines changed

7 files changed

+638
-23
lines changed

brainbox/behavior/training.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def compute_training_info(trials, trials_all):
349349
perf_easy = np.array([compute_performance_easy(trials[k]) for k in trials.keys()])
350350
n_trials = np.array([compute_n_trials(trials[k]) for k in trials.keys()])
351351
psych = compute_psychometric(trials_all, signed_contrast=signed_contrast)
352-
rt = compute_median_reaction_time(trials_all, signed_contrast=signed_contrast)
352+
rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)
353353

354354
return perf_easy, n_trials, psych, rt
355355

@@ -376,7 +376,7 @@ def compute_bias_info(trials, trials_all):
376376
n_trials = np.array([compute_n_trials(trials[k]) for k in trials.keys()])
377377
psych_20 = compute_psychometric(trials_all, signed_contrast=signed_contrast, block=0.2)
378378
psych_80 = compute_psychometric(trials_all, signed_contrast=signed_contrast, block=0.8)
379-
rt = compute_median_reaction_time(trials_all, signed_contrast=signed_contrast)
379+
rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)
380380

381381
return perf_easy, n_trials, psych_20, psych_80, rt
382382

@@ -452,7 +452,7 @@ def compute_n_trials(trials):
452452
return trials['choice'].shape[0]
453453

454454

455-
def compute_psychometric(trials, signed_contrast=None, block=None):
455+
def compute_psychometric(trials, signed_contrast=None, block=None, plotting=False):
456456
"""
457457
Compute psychometric fit parameters for trials object
458458
@@ -479,17 +479,27 @@ def compute_psychometric(trials, signed_contrast=None, block=None):
479479
prob_choose_right, contrasts, n_contrasts = compute_performance(trials, signed_contrast=signed_contrast, block=block,
480480
prob_right=True)
481481

482-
psych, _ = psy.mle_fit_psycho(
483-
np.vstack([contrasts, n_contrasts, prob_choose_right]),
484-
P_model='erf_psycho_2gammas',
485-
parstart=np.array([np.mean(contrasts), 20., 0.05, 0.05]),
486-
parmin=np.array([np.min(contrasts), 0., 0., 0.]),
487-
parmax=np.array([np.max(contrasts), 100., 1, 1]))
482+
if plotting:
483+
psych, _ = psy.mle_fit_psycho(
484+
np.vstack([contrasts, n_contrasts, prob_choose_right]),
485+
P_model='erf_psycho_2gammas',
486+
parstart=np.array([0., 40., 0.1, 0.1]),
487+
parmin=np.array([-50., 10., 0., 0.]),
488+
parmax=np.array([50., 50., 0.2, 0.2]),
489+
nfits=10)
490+
else:
491+
492+
psych, _ = psy.mle_fit_psycho(
493+
np.vstack([contrasts, n_contrasts, prob_choose_right]),
494+
P_model='erf_psycho_2gammas',
495+
parstart=np.array([np.mean(contrasts), 20., 0.05, 0.05]),
496+
parmin=np.array([np.min(contrasts), 0., 0., 0.]),
497+
parmax=np.array([np.max(contrasts), 100., 1, 1]))
488498

489499
return psych
490500

491501

492-
def compute_median_reaction_time(trials, stim_on_type='stimOn_times', signed_contrast=None):
502+
def compute_median_reaction_time(trials, stim_on_type='stimOn_times', contrast=None, signed_contrast=None):
493503
"""
494504
Compute median reaction time on zero contrast trials from trials object
495505
@@ -505,10 +515,15 @@ def compute_median_reaction_time(trials, stim_on_type='stimOn_times', signed_con
505515
"""
506516
if signed_contrast is None:
507517
signed_contrast = get_signed_contrast(trials)
508-
zero_trials = (trials.response_times - trials[stim_on_type])[signed_contrast == 0]
509-
if np.any(zero_trials):
518+
519+
if contrast is None:
520+
contrast_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
521+
else:
522+
contrast_idx = signed_contrast == contrast
523+
524+
if np.any(contrast_idx):
510525
reaction_time = np.nanmedian((trials.response_times - trials[stim_on_type])
511-
[signed_contrast == 0])
526+
[contrast_idx])
512527
else:
513528
reaction_time = np.nan
514529

@@ -590,15 +605,15 @@ def plot_psychometric(trials, ax=None, title=None, **kwargs):
590605
contrasts_fit = np.arange(-100, 100)
591606

592607
prob_right_50, contrasts_50, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5, prob_right=True)
593-
pars_50 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.5)
608+
pars_50 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.5, plotting=True)
594609
prob_right_fit_50 = psy.erf_psycho_2gammas(pars_50, contrasts_fit)
595610

596611
prob_right_20, contrasts_20, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2, prob_right=True)
597-
pars_20 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.2)
612+
pars_20 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.2, plotting=True)
598613
prob_right_fit_20 = psy.erf_psycho_2gammas(pars_20, contrasts_fit)
599614

600615
prob_right_80, contrasts_80, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8, prob_right=True)
601-
pars_80 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.8)
616+
pars_80 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.8, plotting=True)
602617
prob_right_fit_80 = psy.erf_psycho_2gammas(pars_80, contrasts_fit)
603618

604619
cmap = sns.diverging_palette(20, 220, n=3, center="dark")

brainbox/tests/test_behavior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def test_concatenate_and_computations(self):
152152
perf_easy = np.array([train.compute_performance_easy(trials[k]) for k in trials.keys()])
153153
n_trials = np.array([train.compute_n_trials(trials[k]) for k in trials.keys()])
154154
psych = train.compute_psychometric(trials_all)
155-
rt = train.compute_median_reaction_time(trials_all)
155+
rt = train.compute_median_reaction_time(trials_all, contrast=0)
156156
np.testing.assert_allclose(perf_easy, [0.91489362, 0.9, 0.90853659])
157157
np.testing.assert_array_equal(n_trials, [617, 532, 719])
158158
np.testing.assert_allclose(psych, [4.04487042, 21.6293942, 1.91451396e-02, 1.72669957e-01],

ibllib/pipes/ephys_preprocessing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ibllib.io.extractors import ephys_fpga, ephys_passive, camera
2323
from ibllib.pipes import tasks
2424
from ibllib.pipes.training_preprocessing import TrainingRegisterRaw as EphysRegisterRaw
25+
from ibllib.pipes.training_preprocessing import TrainingStatus as EphysTrainingStatus
2526
from ibllib.pipes.misc import create_alyx_probe_insertions
2627
from ibllib.qc.alignment_qc import get_aligned_channels
2728
from ibllib.qc.task_extractors import TaskQCExtractor
@@ -1287,6 +1288,7 @@ def __init__(self, session_path=None, **kwargs):
12871288
self.session_path, parents=[tasks["EphysVideoCompress"], tasks["EphysPulses"], tasks["EphysTrials"]])
12881289
tasks["EphysCellsQc"] = EphysCellsQc(self.session_path, parents=[tasks["SpikeSorting"]])
12891290
tasks["EphysDLC"] = EphysDLC(self.session_path, parents=[tasks["EphysVideoCompress"]])
1291+
tasks['EphysTrainingStatus'] = EphysTrainingStatus(self.session_path, parents=[tasks["EphysTrials"]])
12901292
# level 3
12911293
tasks["EphysPostDLC"] = EphysPostDLC(self.session_path, parents=[tasks["EphysDLC"], tasks["EphysTrials"],
12921294
tasks["EphysVideoSyncQc"]])

ibllib/pipes/training_preprocessing.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from collections import OrderedDict
33

4-
from ibllib.pipes import tasks
4+
from ibllib.pipes import tasks, training_status
55
from ibllib.io import ffmpeg
66
from ibllib.io.extractors.base import get_session_extractor_type
77
from ibllib.io.extractors import training_audio, bpod_trials, camera
@@ -107,6 +107,27 @@ def _run(self):
107107
pass
108108

109109

110+
class TrainingStatus(tasks.Task):
111+
priority = 90
112+
level = 1
113+
force = False
114+
signature = {
115+
'input_files': [('_iblrig_taskData.raw.*', 'raw_behavior_data', True),
116+
('_iblrig_taskSettings.raw.*', 'raw_behavior_data', True),
117+
('*trials.table.pqt', 'alf', True)],
118+
'output_files': []
119+
}
120+
121+
def _run(self, upload=True):
122+
"""
123+
Extracts training status for subject
124+
"""
125+
df = training_status.get_latest_training_information(self.session_path, self.one)
126+
training_status.make_plots(self.session_path, self.one, df=df, save=True, upload=upload)
127+
output_files = []
128+
return output_files
129+
130+
110131
class TrainingExtractionPipeline(tasks.Pipeline):
111132
label = __name__
112133

@@ -120,6 +141,7 @@ def __init__(self, session_path, **kwargs):
120141
tasks['TrainingVideoCompress'] = TrainingVideoCompress(self.session_path)
121142
tasks['TrainingAudio'] = TrainingAudio(self.session_path)
122143
# level 1
144+
tasks['TrainingStatus'] = TrainingStatus(self.session_path, parents=[tasks['TrainingTrials']])
123145
tasks['TrainingDLC'] = TrainingDLC(
124146
self.session_path, parents=[tasks['TrainingVideoCompress']])
125147
self.tasks = tasks

0 commit comments

Comments
 (0)