1616from pathlib import Path
1717import matplotlib .pyplot as plt
1818import matplotlib .dates as mdates
19+ from matplotlib .lines import Line2D
1920from datetime import datetime
21+ import seaborn as sns
2022
2123one = ONE ()
2224
@@ -309,12 +311,26 @@ def get_training_info_for_session(session_paths, one):
309311 continue
310312
311313 sess_dict ['performance' ], sess_dict ['contrasts' ], _ = training .compute_performance (trials , prob_right = True )
314+ if sess_dict ['task_protocol' ] == 'training' :
315+ sess_dict ['bias_50' ], sess_dict ['thres_50' ], sess_dict ['lapsehigh_50' ], sess_dict ['lapselow_50' ] = \
316+ training .compute_psychometric (trials )
317+ sess_dict ['bias_20' ], sess_dict ['thres_20' ], sess_dict ['lapsehigh_20' ], sess_dict ['lapselow_20' ] = \
318+ (np .nan , np .nan , np .nan , np .nan )
319+ sess_dict ['bias_80' ], sess_dict ['thres_80' ], sess_dict ['lapsehigh_80' ], sess_dict ['lapselow_80' ] = \
320+ (np .nan , np .nan , np .nan , np .nan )
321+ else :
322+ sess_dict ['bias_50' ], sess_dict ['thres_50' ], sess_dict ['lapsehigh_50' ], sess_dict ['lapselow_50' ] = \
323+ training .compute_psychometric (trials , block = 0.5 )
324+ sess_dict ['bias_20' ], sess_dict ['thres_20' ], sess_dict ['lapsehigh_20' ], sess_dict ['lapselow_20' ] = \
325+ training .compute_psychometric (trials , block = 0.2 )
326+ sess_dict ['bias_80' ], sess_dict ['thres_80' ], sess_dict ['lapsehigh_80' ], sess_dict ['lapselow_80' ] = \
327+ training .compute_psychometric (trials , block = 0.8 )
328+
312329 sess_dict ['performance_easy' ] = training .compute_performance_easy (trials )
313330 sess_dict ['reaction_time' ] = training .compute_median_reaction_time (trials )
314331 sess_dict ['n_trials' ] = training .compute_n_trials (trials )
315332 sess_dict ['sess_duration' ], sess_dict ['n_delay' ], sess_dict ['location' ] = \
316333 compute_session_duration_delay_location (session_path )
317- sess_dict ['task_protocol' ] = get_session_extractor_type (session_path )
318334 sess_dict ['training_status' ] = 'not_computed'
319335
320336 sess_dicts .append (sess_dict )
@@ -328,6 +344,11 @@ def get_training_info_for_session(session_paths, one):
328344 print (f'{ len (sess_dicts )} sessions being combined for date { sess_dicts [0 ]["date" ]} ' )
329345 combined_trials = load_combined_trials (session_paths , one )
330346 performance , contrasts , _ = training .compute_performance (combined_trials , prob_right = True )
347+ psychs = {}
348+ psychs ['50' ] = training .compute_psychometric (trials , block = 0.5 )
349+ psychs ['20' ] = training .compute_psychometric (trials , block = 0.2 )
350+ psychs ['80' ] = training .compute_psychometric (trials , block = 0.8 )
351+
331352 performance_easy = training .compute_performance_easy (combined_trials )
332353 reaction_time = training .compute_median_reaction_time (combined_trials )
333354 n_trials = training .compute_n_trials (combined_trials )
@@ -344,6 +365,12 @@ def get_training_info_for_session(session_paths, one):
344365 sess_dict ['combined_sess_duration' ] = sess_duration
345366 sess_dict ['combined_n_delay' ] = n_delay
346367
368+ for bias in [50 , 20 , 80 ]:
369+ sess_dict [f'combined_bias_{ bias } ' ] = psychs [f'{ bias } ' ][0 ]
370+ sess_dict [f'combined_thres_{ bias } ' ] = psychs [f'{ bias } ' ][1 ]
371+ sess_dict [f'combined_lapsehigh_{ bias } ' ] = psychs [f'{ bias } ' ][2 ]
372+ sess_dict [f'combined_lapselow_{ bias } ' ] = psychs [f'{ bias } ' ][3 ]
373+
347374 # Case where two sessions on same day with different number of contrasts! Oh boy
348375 if sess_dict ['combined_performance' ].size != sess_dict ['performance' ].size :
349376 sess_dict ['performance' ] = \
@@ -363,6 +390,12 @@ def get_training_info_for_session(session_paths, one):
363390 sess_dict ['combined_sess_duration' ] = sess_dict ['sess_duration' ]
364391 sess_dict ['combined_n_delay' ] = sess_dict ['n_delay' ]
365392
393+ for bias in [50 , 20 , 80 ]: # TODO check with someone if this is the way to do it
394+ sess_dict [f'combined_bias_{ bias } ' ] = sess_dict [f'bias_{ bias } ' ]
395+ sess_dict [f'combined_thres_{ bias } ' ] = sess_dict [f'thres_{ bias } ' ]
396+ sess_dict [f'combined_lapsehigh_{ bias } ' ] = sess_dict [f'lapsehigh_{ bias } ' ]
397+ sess_dict [f'combined_lapselow_{ bias } ' ] = sess_dict [f'lapselow_{ bias } ' ]
398+
366399 return sess_dicts
367400
368401
@@ -384,7 +417,7 @@ def check_up_to_date(subj_path, df):
384417 df_session = pd .concat ([df_session , pd .DataFrame ({'date' : date , 'session_path' : str (sess )}, index = [0 ])],
385418 ignore_index = True )
386419
387- if df is None :
420+ if df is None or 'combined_thres_50' not in df . columns :
388421 return df_session
389422 else :
390423 # recorded_session_paths = df['session_path'].values
@@ -399,14 +432,18 @@ def plot_trial_count_and_session_duration(df, subject):
399432
400433 y1 = {'column' : 'combined_n_trials' ,
401434 'title' : 'Trial counts' ,
402- 'lim' : None }
435+ 'lim' : None ,
436+ 'color' : 'k' ,
437+ 'join' : True }
403438
404439 y2 = {'column' : 'combined_sess_duration' ,
405440 'title' : 'Session duration (mins)' ,
406441 'lim' : None ,
407- 'log' : False }
442+ 'color' : 'r' ,
443+ 'log' : False ,
444+ 'join' : True }
408445
409- ax = plot_over_days (df , y1 , y2 , subject )
446+ ax = plot_over_days (df , subject , y1 , y2 )
410447
411448 return ax
412449
@@ -416,40 +453,152 @@ def plot_performance_easy_median_reaction_time(df, subject):
416453
417454 y1 = {'column' : 'combined_performance_easy' ,
418455 'title' : 'Performance on easy trials' ,
419- 'lim' : [0 , 1.05 ]}
456+ 'lim' : [0 , 1.05 ],
457+ 'color' : 'k' ,
458+ 'join' : True }
420459
421460 y2 = {'column' : 'combined_reaction_time' ,
422461 'title' : 'Median reaction time (s)' ,
423462 'lim' : [0.1 , np .nanmax ([10 , np .nanmax (df .combined_reaction_time .values )])],
424- 'log' : True }
425- ax = plot_over_days (df , y1 , y2 , subject )
463+ 'color' : 'r' ,
464+ 'log' : True ,
465+ 'join' : True }
466+ ax = plot_over_days (df , subject , y1 , y2 )
426467
427468 return ax
428469
429470
430- def plot_over_days (df , y1 , y2 , subject , ax = None ):
471+ def plot_fit_params (df , subject ):
472+ fig , axs = plt .subplots (2 , 2 , figsize = (12 , 6 ))
473+ axs = axs .ravel ()
474+
475+ df = df .drop_duplicates ('date' ).reset_index (drop = True )
476+
477+ cmap = sns .diverging_palette (20 , 220 , n = 3 , center = "dark" )
478+
479+ y50 = {'column' : 'combined_bias_50' ,
480+ 'title' : 'Bias' ,
481+ 'lim' : [- 100 , 100 ],
482+ 'color' : cmap [1 ],
483+ 'join' : False }
484+
485+ y80 = {'column' : 'combined_bias_80' ,
486+ 'title' : 'Bias' ,
487+ 'lim' : [- 100 , 100 ],
488+ 'color' : cmap [2 ],
489+ 'join' : False }
490+
491+ y20 = {'column' : 'combined_bias_20' ,
492+ 'title' : 'Bias' ,
493+ 'lim' : [- 100 , 100 ],
494+ 'color' : cmap [0 ],
495+ 'join' : False }
496+
497+ plot_over_days (df , subject , y50 , ax = axs [0 ], legend = False , title = False )
498+ plot_over_days (df , subject , y80 , ax = axs [0 ], legend = False , title = False )
499+ plot_over_days (df , subject , y20 , ax = axs [0 ], legend = False , title = False )
500+ axs [0 ].axhline (16 , linewidth = 2 , linestyle = '--' , color = 'k' )
501+ axs [0 ].axhline (- 16 , linewidth = 2 , linestyle = '--' , color = 'k' )
502+
503+ y50 ['column' ] = 'combined_thres_50'
504+ y50 ['title' ] = 'Threshold'
505+ y50 ['lim' ] = [0 , 100 ]
506+ y80 ['column' ] = 'combined_thres_20'
507+ y80 ['title' ] = 'Threshold'
508+ y20 ['lim' ] = [0 , 100 ]
509+ y20 ['column' ] = 'combined_thres_80'
510+ y20 ['title' ] = 'Threshold'
511+ y80 ['lim' ] = [0 , 100 ]
512+
513+ plot_over_days (df , subject , y50 , ax = axs [1 ], legend = False , title = False )
514+ plot_over_days (df , subject , y80 , ax = axs [1 ], legend = False , title = False )
515+ plot_over_days (df , subject , y20 , ax = axs [1 ], legend = False , title = False )
516+ axs [1 ].axhline (19 , linewidth = 2 , linestyle = '--' , color = 'k' )
517+
518+ y50 ['column' ] = 'combined_lapselow_50'
519+ y50 ['title' ] = 'Lapse Low'
520+ y50 ['lim' ] = [0 , 1 ]
521+ y80 ['column' ] = 'combined_lapselow_20'
522+ y80 ['title' ] = 'Lapse Low'
523+ y80 ['lim' ] = [0 , 1 ]
524+ y20 ['column' ] = 'combined_lapselow_80'
525+ y20 ['title' ] = 'Lapse Low'
526+ y20 ['lim' ] = [0 , 1 ]
527+
528+ plot_over_days (df , subject , y50 , ax = axs [2 ], legend = False , title = False )
529+ plot_over_days (df , subject , y80 , ax = axs [2 ], legend = False , title = False )
530+ plot_over_days (df , subject , y20 , ax = axs [2 ], legend = False , title = False )
531+ axs [2 ].axhline (0.2 , linewidth = 2 , linestyle = '--' , color = 'k' )
532+
533+ y50 ['column' ] = 'combined_lapsehigh_50'
534+ y50 ['title' ] = 'Lapse High'
535+ y50 ['lim' ] = [0 , 1 ]
536+ y80 ['column' ] = 'combined_lapsehigh_20'
537+ y80 ['title' ] = 'Lapse High'
538+ y80 ['lim' ] = [0 , 1 ]
539+ y20 ['column' ] = 'combined_lapsehigh_80'
540+ y20 ['title' ] = 'Lapse High'
541+ y20 ['lim' ] = [0 , 1 ]
542+
543+ plot_over_days (df , subject , y50 , ax = axs [3 ], legend = False , title = False , training_lines = True )
544+ plot_over_days (df , subject , y80 , ax = axs [3 ], legend = False , title = False , training_lines = False )
545+ plot_over_days (df , subject , y20 , ax = axs [3 ], legend = False , title = False , training_lines = False )
546+ axs [3 ].axhline (0.2 , linewidth = 2 , linestyle = '--' , color = 'k' )
547+
548+ fig .suptitle (f'{ subject } { df .iloc [- 1 ]["date" ]} : { df .iloc [- 1 ]["training_status" ]} ' )
549+ lines , labels = axs [3 ].get_legend_handles_labels ()
550+ fig .legend (lines , labels , loc = 'upper center' , bbox_to_anchor = (0.5 , 0.1 ), fancybox = True , shadow = True , ncol = 5 )
551+
552+ legend_elements = [Line2D ([0 ], [0 ], marker = 'o' , color = 'w' , label = 'p=0.5' , markerfacecolor = cmap [1 ], markersize = 8 ),
553+ Line2D ([0 ], [0 ], marker = 'o' , color = 'w' , label = 'p=0.2' , markerfacecolor = cmap [0 ], markersize = 8 ),
554+ Line2D ([0 ], [0 ], marker = 'o' , color = 'w' , label = 'p=0.8' , markerfacecolor = cmap [2 ], markersize = 8 )]
555+ legend2 = plt .legend (handles = legend_elements , loc = 'upper right' , bbox_to_anchor = (1.1 , - 0.2 ), fancybox = True , shadow = True )
556+ fig .add_artist (legend2 )
557+
558+ return axs
559+
560+
561+ def plot_psychometric_curve (df , subject , one ):
562+ df = df .drop_duplicates ('date' ).reset_index (drop = True )
563+ sess_path = Path (df .iloc [- 1 ]["session_path" ])
564+ trials = load_trials (sess_path , one )
565+
566+ fig , ax1 = plt .subplots (figsize = (8 , 6 ))
567+
568+ training .plot_psychometric (trials , ax = ax1 , title = f'{ subject } { df .iloc [- 1 ]["date" ]} : { df .iloc [- 1 ]["training_status" ]} ' )
569+
570+ return ax1
571+
572+
573+ def plot_over_days (df , subject , y1 , y2 = None , ax = None , legend = True , title = True , training_lines = True ):
431574
432575 if ax is None :
433576 fig , ax1 = plt .subplots (figsize = (12 , 6 ))
434577 else :
435578 ax1 = ax
436579
437- ax2 = ax1 .twinx ()
438-
439580 dates = [datetime .strptime (dat , '%Y-%m-%d' ) for dat in df ['date' ]]
440- ax1 .plot (dates , df [y1 ['column' ]], 'k' )
441- ax1 .scatter (dates , df [y1 ['column' ]], c = 'k' )
581+ if y1 ['join' ]:
582+ ax1 .plot (dates , df [y1 ['column' ]], color = y1 ['color' ])
583+ ax1 .scatter (dates , df [y1 ['column' ]], color = y1 ['color' ])
442584 ax1 .set_ylabel (y1 ['title' ])
443585 ax1 .set_ylim (y1 ['lim' ])
444586
445- ax2 .plot (dates , df [y2 ['column' ]], 'r' )
446- ax2 .scatter (dates , df [y2 ['column' ]], c = 'r' )
447- ax2 .set_ylabel (y2 ['title' ])
448- ax2 .yaxis .label .set_color ('r' )
449- ax2 .tick_params (axis = 'y' , colors = 'r' )
450- ax2 .set_ylim (y2 ['lim' ])
451- if y2 ['log' ]:
452- ax2 .set_yscale ('log' )
587+ if y2 is not None :
588+ ax2 = ax1 .twinx ()
589+ if y2 ['join' ]:
590+ ax2 .plot (dates , df [y2 ['column' ]], color = y2 ['color' ])
591+ ax2 .scatter (dates , df [y2 ['column' ]], color = y2 ['color' ])
592+ ax2 .set_ylabel (y2 ['title' ])
593+ ax2 .yaxis .label .set_color (y2 ['color' ])
594+ ax2 .tick_params (axis = 'y' , colors = y2 ['color' ])
595+ ax2 .set_ylim (y2 ['lim' ])
596+ if y2 ['log' ]:
597+ ax2 .set_yscale ('log' )
598+
599+ ax2 .spines ['right' ].set_visible (False )
600+ ax2 .spines ['top' ].set_visible (False )
601+ ax2 .spines ['left' ].set_visible (False )
453602
454603 month_format = mdates .DateFormatter ('%b %Y' )
455604 month_locator = mdates .MonthLocator ()
@@ -462,20 +611,20 @@ def plot_over_days(df, y1, y2, subject, ax=None):
462611 ax1 .spines ['left' ].set_visible (False )
463612 ax1 .spines ['right' ].set_visible (False )
464613 ax1 .spines ['top' ].set_visible (False )
465- ax2 .spines ['right' ].set_visible (False )
466- ax2 .spines ['top' ].set_visible (False )
467- ax2 .spines ['left' ].set_visible (False )
468614
469- ax1 = add_training_lines (df , ax1 )
615+ if training_lines :
616+ ax1 = add_training_lines (df , ax1 )
470617
471- ax1 .set_title (f'{ subject } { df .iloc [- 1 ]["date" ]} : { df .iloc [- 1 ]["training_status" ]} ' )
472- box = ax1 .get_position ()
473- ax1 .set_position ([box .x0 , box .y0 + box .height * 0.1 ,
474- box .width , box .height * 0.9 ])
618+ if title :
619+ ax1 .set_title (f'{ subject } { df .iloc [- 1 ]["date" ]} : { df .iloc [- 1 ]["training_status" ]} ' )
475620
476621 # Put a legend below current axis
477- ax1 .legend (loc = 'upper center' , bbox_to_anchor = (0.5 , - 0.1 ),
478- fancybox = True , shadow = True , ncol = 5 )
622+ box = ax1 .get_position ()
623+ ax1 .set_position ([box .x0 , box .y0 + box .height * 0.1 ,
624+ box .width , box .height * 0.9 ])
625+ if legend :
626+ ax1 .legend (loc = 'upper center' , bbox_to_anchor = (0.5 , - 0.1 ),
627+ fancybox = True , shadow = True , ncol = 5 )
479628
480629 return ax1
481630
@@ -554,6 +703,8 @@ def make_plots(session_path, one, df=None, save=False, upload=False):
554703 ax1 = plot_trial_count_and_session_duration (df , subject )
555704 ax2 = plot_performance_easy_median_reaction_time (df , subject )
556705 ax3 = plot_heatmap_performance_over_days (df , subject )
706+ ax4 = plot_fit_params (df , subject )
707+ ax5 = plot_psychometric_curve (df , subject , one )
557708
558709 outputs = []
559710 if save :
@@ -570,6 +721,14 @@ def make_plots(session_path, one, df=None, save=False, upload=False):
570721 outputs .append (save_name )
571722 ax3 .get_figure ().savefig (save_name , bbox_inches = 'tight' )
572723
724+ save_name = save_path .joinpath ('subj_psychometric_fit_params.png' )
725+ outputs .append (save_name )
726+ ax4 [0 ].get_figure ().savefig (save_name , bbox_inches = 'tight' )
727+
728+ save_name = save_path .joinpath ('subj_psychometric_curve.png' )
729+ outputs .append (save_name )
730+ ax5 .get_figure ().savefig (save_name , bbox_inches = 'tight' )
731+
573732 if upload :
574733 subj = one .alyx .rest ('subjects' , 'list' , nickname = subject )[0 ]
575734 snp = ReportSnapshot (session_path , subj ['id' ], content_type = 'subject' , one = one )
0 commit comments