55from iblutil .util import Bunch
66import brainbox .behavior .pyschofit as psy
77import logging
8+ import matplotlib
9+ import matplotlib .pyplot as plt
10+ import seaborn as sns
11+ import pandas as pd
12+
813_logger = logging .getLogger ('ibllib' )
914
15+ TRIALS_KEYS = ['contrastLeft' ,
16+ 'contrastRight' ,
17+ 'feedbackType' ,
18+ 'probabilityLeft' ,
19+ 'choice' ,
20+ 'response_times' ,
21+ 'stimOn_times' ]
22+
1023
1124def get_lab_training_status (lab , date = None , details = True , one = None ):
1225 """
@@ -303,14 +316,14 @@ def concatenate_trials(trials):
303316 """
304317 Concatenate trials from different training sessions
305318
306- :param trials: dict containing trials objects from three consective training sessions,
319+ :param trials: dict containing trials objects from three consecutive training sessions,
307320 keys are session dates
308321 :type trials: Bunch
309322 :return: trials object with data concatenated over three training sessions
310323 :rtype: dict
311324 """
312325 trials_all = Bunch ()
313- for k in trials [ list ( trials . keys ())[ 0 ]]. keys () :
326+ for k in TRIALS_KEYS :
314327 trials_all [k ] = np .concatenate (list (trials [kk ][k ] for kk in trials .keys ()))
315328
316329 return trials_all
@@ -395,6 +408,35 @@ def compute_performance_easy(trials):
395408 return np .sum (trials ['feedbackType' ][easy_trials ] == 1 ) / easy_trials .shape [0 ]
396409
397410
411+ def compute_performance (trials , signed_contrast = None , block = None ):
412+ """
413+ Compute performance on all trials at each contrast level from trials object
414+
415+ :param trials: trials object that must contain contrastLeft, contrastRight and feedbackType
416+ keys
417+ :type trials: dict
418+ returns: float containing performance on easy contrast trials
419+ """
420+ if signed_contrast is None :
421+ signed_contrast = get_signed_contrast (trials )
422+
423+ if block is None :
424+ block_idx = np .full (trials .probabilityLeft .shape , True , dtype = bool )
425+ else :
426+ block_idx = trials .probabilityLeft == block
427+
428+ if not np .any (block_idx ):
429+ return np .nan * np .zeros (2 )
430+
431+ contrasts , n_contrasts = np .unique (signed_contrast [block_idx ], return_counts = True )
432+ rightward = trials .choice == - 1
433+ # Calculate the proportion rightward for each contrast type
434+ prob_choose_right = np .vectorize (lambda x : np .mean (rightward [(x == signed_contrast ) &
435+ block_idx ]))(contrasts )
436+
437+ return prob_choose_right , contrasts , n_contrasts
438+
439+
398440def compute_n_trials (trials ):
399441 """
400442 Compute number of trials in trials object
@@ -418,6 +460,7 @@ def compute_psychometric(trials, signed_contrast=None, block=None):
418460 :type block: float
419461 :return: array of psychometric fit parameters - bias, threshold, lapse high, lapse low
420462 """
463+
421464 if signed_contrast is None :
422465 signed_contrast = get_signed_contrast (trials )
423466
@@ -429,11 +472,7 @@ def compute_psychometric(trials, signed_contrast=None, block=None):
429472 if not np .any (block_idx ):
430473 return np .nan * np .zeros (4 )
431474
432- contrasts , n_contrasts = np .unique (signed_contrast [block_idx ], return_counts = True )
433- rightward = trials .choice == - 1
434- # Calculate the proportion rightward for each contrast type
435- prob_choose_right = np .vectorize (lambda x : np .mean (rightward [(x == signed_contrast ) &
436- block_idx ]))(contrasts )
475+ prob_choose_right , contrasts , n_contrasts = compute_performance (trials , signed_contrast = signed_contrast , block = block )
437476
438477 psych , _ = psy .mle_fit_psycho (
439478 np .vstack ([contrasts , n_contrasts , prob_choose_right ]),
@@ -471,6 +510,31 @@ def compute_median_reaction_time(trials, stim_on_type='stimOn_times', signed_con
471510 return reaction_time
472511
473512
513+ def compute_reaction_time (trials , stim_on_type = 'stimOn_times' , signed_contrast = None , block = None ):
514+ """
515+ Compute median reaction time for all contrasts
516+ :param trials: trials object that must contain response_times and stimOn_times
517+ :param stim_on_type:
518+ :param signed_contrast:
519+ :param block:
520+ :return:
521+ """
522+
523+ if signed_contrast is None :
524+ signed_contrast = get_signed_contrast (trials )
525+
526+ if block is None :
527+ block_idx = np .full (trials .probabilityLeft .shape , True , dtype = bool )
528+ else :
529+ block_idx = trials .probabilityLeft == block
530+
531+ contrasts , n_contrasts = np .unique (signed_contrast [block_idx ], return_counts = True )
532+ reaction_time = np .vectorize (lambda x : np .nanmedian ((trials .response_times - trials [stim_on_type ])
533+ [(x == signed_contrast ) & block_idx ]))(contrasts )
534+
535+ return reaction_time , contrasts , n_contrasts
536+
537+
474538def criterion_1a (psych , n_trials , perf_easy ):
475539 """
476540 Returns bool indicating whether criterion for trained_1a is met. All criteria documented here
@@ -508,3 +572,126 @@ def criterion_delay(n_trials, perf_easy):
508572 """
509573 criterion = np .any (n_trials > 400 ) and np .any (perf_easy > 0.9 )
510574 return criterion
575+
576+
577+ def plot_psychometric (trials , ax = None , title = None , ** kwargs ):
578+ """
579+ Function to plot pyschometric curve plots a la datajoint webpage
580+ :param trials:
581+ :return:
582+ """
583+
584+ signed_contrast = get_signed_contrast (trials )
585+ contrasts_fit = np .arange (- 100 , 100 )
586+
587+ prob_right_50 , contrasts , _ = compute_performance (trials , signed_contrast = signed_contrast , block = 0.5 )
588+ pars_50 = compute_psychometric (trials , signed_contrast = signed_contrast , block = 0.5 )
589+ prob_right_fit_50 = psy .erf_psycho_2gammas (pars_50 , contrasts_fit )
590+
591+ prob_right_20 , contrasts , _ = compute_performance (trials , signed_contrast = signed_contrast , block = 0.2 )
592+ pars_20 = compute_psychometric (trials , signed_contrast = signed_contrast , block = 0.2 )
593+ prob_right_fit_20 = psy .erf_psycho_2gammas (pars_20 , contrasts_fit )
594+
595+ prob_right_80 , contrasts , _ = compute_performance (trials , signed_contrast = signed_contrast , block = 0.8 )
596+ pars_80 = compute_psychometric (trials , signed_contrast = signed_contrast , block = 0.8 )
597+ prob_right_fit_80 = psy .erf_psycho_2gammas (pars_80 , contrasts_fit )
598+
599+ cmap = sns .diverging_palette (20 , 220 , n = 3 , center = "dark" )
600+
601+ if not ax :
602+ fig , ax = plt .subplots (** kwargs )
603+ else :
604+ fig = plt .gcf ()
605+
606+ # TODO error bars
607+
608+ fit_50 = ax .plot (contrasts_fit , prob_right_fit_50 , color = cmap [1 ])
609+ data_50 = ax .scatter (contrasts , prob_right_50 , color = cmap [1 ])
610+ fit_20 = ax .plot (contrasts_fit , prob_right_fit_20 , color = cmap [0 ])
611+ data_20 = ax .scatter (contrasts , prob_right_20 , color = cmap [0 ])
612+ fit_80 = ax .plot (contrasts_fit , prob_right_fit_80 , color = cmap [2 ])
613+ data_80 = ax .scatter (contrasts , prob_right_80 , color = cmap [2 ])
614+ ax .legend ([fit_50 [0 ], data_50 , fit_20 [0 ], data_20 , fit_80 [0 ], data_80 ],
615+ ['p_left=0.5 fit' , 'p_left=0.5 data' , 'p_left=0.2 fit' , 'p_left=0.2 data' , 'p_left=0.8 fit' , 'p_left=0.8 data' ],
616+ loc = 'upper left' )
617+ ax .set_ylim (- 0.05 , 1.05 )
618+ ax .set_ylabel ('Probability choosing right' )
619+ ax .set_xlabel ('Contrasts' )
620+ if title :
621+ ax .set_title (title )
622+
623+ return fig , ax
624+
625+
626+ def plot_reaction_time (trials , ax = None , title = None , ** kwargs ):
627+ """
628+ Function to plot reaction time against contrast a la datajoint webpage (inversed for some reason??)
629+ :param trials:
630+ :return:
631+ """
632+
633+ signed_contrast = get_signed_contrast (trials )
634+ reaction_50 , contrasts , _ = compute_reaction_time (trials , signed_contrast = signed_contrast , block = 0.5 )
635+ reaction_20 , contrasts , _ = compute_reaction_time (trials , signed_contrast = signed_contrast , block = 0.2 )
636+ reaction_80 , contrasts , _ = compute_reaction_time (trials , signed_contrast = signed_contrast , block = 0.8 )
637+
638+ cmap = sns .diverging_palette (20 , 220 , n = 3 , center = "dark" )
639+
640+ if not ax :
641+ fig , ax = plt .subplots (** kwargs )
642+ else :
643+ fig = plt .gcf ()
644+
645+ data_50 = ax .plot (contrasts , reaction_50 , '-o' , color = cmap [1 ])
646+ data_20 = ax .plot (contrasts , reaction_20 , '-o' , color = cmap [0 ])
647+ data_80 = ax .plot (contrasts , reaction_80 , '-o' , color = cmap [2 ])
648+
649+ # TODO error bars
650+
651+ ax .legend ([data_50 [0 ], data_20 [0 ], data_80 [0 ]],
652+ ['p_left=0.5 data' , 'p_left=0.2 data' , 'p_left=0.8 data' ],
653+ loc = 'upper left' )
654+ ax .set_ylabel ('Reaction time (s)' )
655+ ax .set_xlabel ('Contrasts' )
656+
657+ if title :
658+ ax .set_title (title )
659+
660+ return fig , ax
661+
662+
663+ def plot_reaction_time_over_trials (trials , stim_on_type = 'stimOn_times' , ax = None , title = None , ** kwargs ):
664+ """
665+ Function to plot reaction time with trial number a la datajoint webpage
666+
667+ :param trials:
668+ :param stim_on_type:
669+ :param ax:
670+ :param title:
671+ :param kwargs:
672+ :return:
673+ """
674+
675+ reaction_time = pd .DataFrame ()
676+ reaction_time ['reaction_time' ] = trials .response_times - trials [stim_on_type ]
677+ reaction_time .index = reaction_time .index + 1
678+ reaction_time_rolled = reaction_time ['reaction_time' ].rolling (window = 10 ).median ()
679+ reaction_time_rolled = reaction_time_rolled .where ((pd .notnull (reaction_time_rolled )), None )
680+ reaction_time = reaction_time .where ((pd .notnull (reaction_time )), None )
681+
682+ if not ax :
683+ fig , ax = plt .subplots (** kwargs )
684+ else :
685+ fig = plt .gcf ()
686+
687+ ax .scatter (np .arange (len (reaction_time .values )), reaction_time .values , s = 16 , color = 'darkgray' )
688+ ax .plot (np .arange (len (reaction_time_rolled .values )), reaction_time_rolled .values , color = 'k' , linewidth = 2 )
689+ ax .set_yscale ('log' )
690+ ax .set_ylim (0.1 , 100 )
691+ ax .yaxis .set_major_formatter (matplotlib .ticker .ScalarFormatter ())
692+ ax .set_ylabel ('Reaction time (s)' )
693+ ax .set_xlabel ('Trial number' )
694+ if title :
695+ ax .set_title (title )
696+
697+ return fig , ax
0 commit comments