Skip to content

Commit 8adae0a

Browse files
authored
Merge pull request #437 from int-brain-lab/develop
Release 2.8.0
2 parents 0ce0b86 + a61aa50 commit 8adae0a

File tree

14 files changed

+924
-140
lines changed

14 files changed

+924
-140
lines changed

brainbox/behavior/training.py

Lines changed: 194 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,21 @@
55
from iblutil.util import Bunch
66
import brainbox.behavior.pyschofit as psy
77
import logging
8+
import matplotlib
9+
import matplotlib.pyplot as plt
10+
import seaborn as sns
11+
import pandas as pd
12+
813
_logger = logging.getLogger('ibllib')
914

15+
TRIALS_KEYS = ['contrastLeft',
16+
'contrastRight',
17+
'feedbackType',
18+
'probabilityLeft',
19+
'choice',
20+
'response_times',
21+
'stimOn_times']
22+
1023

1124
def get_lab_training_status(lab, date=None, details=True, one=None):
1225
"""
@@ -303,14 +316,14 @@ def concatenate_trials(trials):
303316
"""
304317
Concatenate trials from different training sessions
305318
306-
:param trials: dict containing trials objects from three consective training sessions,
319+
:param trials: dict containing trials objects from three consecutive training sessions,
307320
keys are session dates
308321
:type trials: Bunch
309322
:return: trials object with data concatenated over three training sessions
310323
:rtype: dict
311324
"""
312325
trials_all = Bunch()
313-
for k in trials[list(trials.keys())[0]].keys():
326+
for k in TRIALS_KEYS:
314327
trials_all[k] = np.concatenate(list(trials[kk][k] for kk in trials.keys()))
315328

316329
return trials_all
@@ -395,6 +408,35 @@ def compute_performance_easy(trials):
395408
return np.sum(trials['feedbackType'][easy_trials] == 1) / easy_trials.shape[0]
396409

397410

411+
def compute_performance(trials, signed_contrast=None, block=None):
412+
"""
413+
Compute performance on all trials at each contrast level from trials object
414+
415+
:param trials: trials object that must contain contrastLeft, contrastRight and feedbackType
416+
keys
417+
:type trials: dict
418+
returns: float containing performance on easy contrast trials
419+
"""
420+
if signed_contrast is None:
421+
signed_contrast = get_signed_contrast(trials)
422+
423+
if block is None:
424+
block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
425+
else:
426+
block_idx = trials.probabilityLeft == block
427+
428+
if not np.any(block_idx):
429+
return np.nan * np.zeros(2)
430+
431+
contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
432+
rightward = trials.choice == -1
433+
# Calculate the proportion rightward for each contrast type
434+
prob_choose_right = np.vectorize(lambda x: np.mean(rightward[(x == signed_contrast) &
435+
block_idx]))(contrasts)
436+
437+
return prob_choose_right, contrasts, n_contrasts
438+
439+
398440
def compute_n_trials(trials):
399441
"""
400442
Compute number of trials in trials object
@@ -418,6 +460,7 @@ def compute_psychometric(trials, signed_contrast=None, block=None):
418460
:type block: float
419461
:return: array of psychometric fit parameters - bias, threshold, lapse high, lapse low
420462
"""
463+
421464
if signed_contrast is None:
422465
signed_contrast = get_signed_contrast(trials)
423466

@@ -429,11 +472,7 @@ def compute_psychometric(trials, signed_contrast=None, block=None):
429472
if not np.any(block_idx):
430473
return np.nan * np.zeros(4)
431474

432-
contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
433-
rightward = trials.choice == -1
434-
# Calculate the proportion rightward for each contrast type
435-
prob_choose_right = np.vectorize(lambda x: np.mean(rightward[(x == signed_contrast) &
436-
block_idx]))(contrasts)
475+
prob_choose_right, contrasts, n_contrasts = compute_performance(trials, signed_contrast=signed_contrast, block=block)
437476

438477
psych, _ = psy.mle_fit_psycho(
439478
np.vstack([contrasts, n_contrasts, prob_choose_right]),
@@ -471,6 +510,31 @@ def compute_median_reaction_time(trials, stim_on_type='stimOn_times', signed_con
471510
return reaction_time
472511

473512

513+
def compute_reaction_time(trials, stim_on_type='stimOn_times', signed_contrast=None, block=None):
514+
"""
515+
Compute median reaction time for all contrasts
516+
:param trials: trials object that must contain response_times and stimOn_times
517+
:param stim_on_type:
518+
:param signed_contrast:
519+
:param block:
520+
:return:
521+
"""
522+
523+
if signed_contrast is None:
524+
signed_contrast = get_signed_contrast(trials)
525+
526+
if block is None:
527+
block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
528+
else:
529+
block_idx = trials.probabilityLeft == block
530+
531+
contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
532+
reaction_time = np.vectorize(lambda x: np.nanmedian((trials.response_times - trials[stim_on_type])
533+
[(x == signed_contrast) & block_idx]))(contrasts)
534+
535+
return reaction_time, contrasts, n_contrasts
536+
537+
474538
def criterion_1a(psych, n_trials, perf_easy):
475539
"""
476540
Returns bool indicating whether criterion for trained_1a is met. All criteria documented here
@@ -508,3 +572,126 @@ def criterion_delay(n_trials, perf_easy):
508572
"""
509573
criterion = np.any(n_trials > 400) and np.any(perf_easy > 0.9)
510574
return criterion
575+
576+
577+
def plot_psychometric(trials, ax=None, title=None, **kwargs):
578+
"""
579+
Function to plot pyschometric curve plots a la datajoint webpage
580+
:param trials:
581+
:return:
582+
"""
583+
584+
signed_contrast = get_signed_contrast(trials)
585+
contrasts_fit = np.arange(-100, 100)
586+
587+
prob_right_50, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5)
588+
pars_50 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.5)
589+
prob_right_fit_50 = psy.erf_psycho_2gammas(pars_50, contrasts_fit)
590+
591+
prob_right_20, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2)
592+
pars_20 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.2)
593+
prob_right_fit_20 = psy.erf_psycho_2gammas(pars_20, contrasts_fit)
594+
595+
prob_right_80, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8)
596+
pars_80 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.8)
597+
prob_right_fit_80 = psy.erf_psycho_2gammas(pars_80, contrasts_fit)
598+
599+
cmap = sns.diverging_palette(20, 220, n=3, center="dark")
600+
601+
if not ax:
602+
fig, ax = plt.subplots(**kwargs)
603+
else:
604+
fig = plt.gcf()
605+
606+
# TODO error bars
607+
608+
fit_50 = ax.plot(contrasts_fit, prob_right_fit_50, color=cmap[1])
609+
data_50 = ax.scatter(contrasts, prob_right_50, color=cmap[1])
610+
fit_20 = ax.plot(contrasts_fit, prob_right_fit_20, color=cmap[0])
611+
data_20 = ax.scatter(contrasts, prob_right_20, color=cmap[0])
612+
fit_80 = ax.plot(contrasts_fit, prob_right_fit_80, color=cmap[2])
613+
data_80 = ax.scatter(contrasts, prob_right_80, color=cmap[2])
614+
ax.legend([fit_50[0], data_50, fit_20[0], data_20, fit_80[0], data_80],
615+
['p_left=0.5 fit', 'p_left=0.5 data', 'p_left=0.2 fit', 'p_left=0.2 data', 'p_left=0.8 fit', 'p_left=0.8 data'],
616+
loc='upper left')
617+
ax.set_ylim(-0.05, 1.05)
618+
ax.set_ylabel('Probability choosing right')
619+
ax.set_xlabel('Contrasts')
620+
if title:
621+
ax.set_title(title)
622+
623+
return fig, ax
624+
625+
626+
def plot_reaction_time(trials, ax=None, title=None, **kwargs):
627+
"""
628+
Function to plot reaction time against contrast a la datajoint webpage (inversed for some reason??)
629+
:param trials:
630+
:return:
631+
"""
632+
633+
signed_contrast = get_signed_contrast(trials)
634+
reaction_50, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.5)
635+
reaction_20, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.2)
636+
reaction_80, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.8)
637+
638+
cmap = sns.diverging_palette(20, 220, n=3, center="dark")
639+
640+
if not ax:
641+
fig, ax = plt.subplots(**kwargs)
642+
else:
643+
fig = plt.gcf()
644+
645+
data_50 = ax.plot(contrasts, reaction_50, '-o', color=cmap[1])
646+
data_20 = ax.plot(contrasts, reaction_20, '-o', color=cmap[0])
647+
data_80 = ax.plot(contrasts, reaction_80, '-o', color=cmap[2])
648+
649+
# TODO error bars
650+
651+
ax.legend([data_50[0], data_20[0], data_80[0]],
652+
['p_left=0.5 data', 'p_left=0.2 data', 'p_left=0.8 data'],
653+
loc='upper left')
654+
ax.set_ylabel('Reaction time (s)')
655+
ax.set_xlabel('Contrasts')
656+
657+
if title:
658+
ax.set_title(title)
659+
660+
return fig, ax
661+
662+
663+
def plot_reaction_time_over_trials(trials, stim_on_type='stimOn_times', ax=None, title=None, **kwargs):
664+
"""
665+
Function to plot reaction time with trial number a la datajoint webpage
666+
667+
:param trials:
668+
:param stim_on_type:
669+
:param ax:
670+
:param title:
671+
:param kwargs:
672+
:return:
673+
"""
674+
675+
reaction_time = pd.DataFrame()
676+
reaction_time['reaction_time'] = trials.response_times - trials[stim_on_type]
677+
reaction_time.index = reaction_time.index + 1
678+
reaction_time_rolled = reaction_time['reaction_time'].rolling(window=10).median()
679+
reaction_time_rolled = reaction_time_rolled.where((pd.notnull(reaction_time_rolled)), None)
680+
reaction_time = reaction_time.where((pd.notnull(reaction_time)), None)
681+
682+
if not ax:
683+
fig, ax = plt.subplots(**kwargs)
684+
else:
685+
fig = plt.gcf()
686+
687+
ax.scatter(np.arange(len(reaction_time.values)), reaction_time.values, s=16, color='darkgray')
688+
ax.plot(np.arange(len(reaction_time_rolled.values)), reaction_time_rolled.values, color='k', linewidth=2)
689+
ax.set_yscale('log')
690+
ax.set_ylim(0.1, 100)
691+
ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
692+
ax.set_ylabel('Reaction time (s)')
693+
ax.set_xlabel('Trial number')
694+
if title:
695+
ax.set_title(title)
696+
697+
return fig, ax

0 commit comments

Comments
 (0)