1- from one .api import ONE
21import one .alf .io as alfio
32from one .alf .spec import is_session_path
43from one .alf .exceptions import ALFObjectNotFound
1615from pathlib import Path
1716import matplotlib .pyplot as plt
1817import matplotlib .dates as mdates
18+ from matplotlib .lines import Line2D
1919from datetime import datetime
20+ import seaborn as sns
2021
21- one = ONE ()
2222
2323TRAINING_STATUS = {'not_computed' : (- 2 , (0 , 0 , 0 , 0 )),
2424 'habituation' : (- 1 , (0 , 0 , 0 , 0 )),
@@ -301,6 +301,12 @@ def get_training_info_for_session(session_paths, one):
301301 sess_dict ['n_delay' ] = np .nan
302302 sess_dict ['location' ] = np .nan
303303 sess_dict ['training_status' ] = 'habituation'
304+ sess_dict ['bias_50' ], sess_dict ['thres_50' ], sess_dict ['lapsehigh_50' ], sess_dict ['lapselow_50' ] = \
305+ (np .nan , np .nan , np .nan , np .nan )
306+ sess_dict ['bias_20' ], sess_dict ['thres_20' ], sess_dict ['lapsehigh_20' ], sess_dict ['lapselow_20' ] = \
307+ (np .nan , np .nan , np .nan , np .nan )
308+ sess_dict ['bias_80' ], sess_dict ['thres_80' ], sess_dict ['lapsehigh_80' ], sess_dict ['lapselow_80' ] = \
309+ (np .nan , np .nan , np .nan , np .nan )
304310
305311 else :
306312 # if we can't compute trials then we need to pass
@@ -309,12 +315,26 @@ def get_training_info_for_session(session_paths, one):
309315 continue
310316
311317 sess_dict ['performance' ], sess_dict ['contrasts' ], _ = training .compute_performance (trials , prob_right = True )
318+ if sess_dict ['task_protocol' ] == 'training' :
319+ sess_dict ['bias_50' ], sess_dict ['thres_50' ], sess_dict ['lapsehigh_50' ], sess_dict ['lapselow_50' ] = \
320+ training .compute_psychometric (trials )
321+ sess_dict ['bias_20' ], sess_dict ['thres_20' ], sess_dict ['lapsehigh_20' ], sess_dict ['lapselow_20' ] = \
322+ (np .nan , np .nan , np .nan , np .nan )
323+ sess_dict ['bias_80' ], sess_dict ['thres_80' ], sess_dict ['lapsehigh_80' ], sess_dict ['lapselow_80' ] = \
324+ (np .nan , np .nan , np .nan , np .nan )
325+ else :
326+ sess_dict ['bias_50' ], sess_dict ['thres_50' ], sess_dict ['lapsehigh_50' ], sess_dict ['lapselow_50' ] = \
327+ training .compute_psychometric (trials , block = 0.5 )
328+ sess_dict ['bias_20' ], sess_dict ['thres_20' ], sess_dict ['lapsehigh_20' ], sess_dict ['lapselow_20' ] = \
329+ training .compute_psychometric (trials , block = 0.2 )
330+ sess_dict ['bias_80' ], sess_dict ['thres_80' ], sess_dict ['lapsehigh_80' ], sess_dict ['lapselow_80' ] = \
331+ training .compute_psychometric (trials , block = 0.8 )
332+
312333 sess_dict ['performance_easy' ] = training .compute_performance_easy (trials )
313334 sess_dict ['reaction_time' ] = training .compute_median_reaction_time (trials )
314335 sess_dict ['n_trials' ] = training .compute_n_trials (trials )
315336 sess_dict ['sess_duration' ], sess_dict ['n_delay' ], sess_dict ['location' ] = \
316337 compute_session_duration_delay_location (session_path )
317- sess_dict ['task_protocol' ] = get_session_extractor_type (session_path )
318338 sess_dict ['training_status' ] = 'not_computed'
319339
320340 sess_dicts .append (sess_dict )
@@ -328,6 +348,11 @@ def get_training_info_for_session(session_paths, one):
328348 print (f'{ len (sess_dicts )} sessions being combined for date { sess_dicts [0 ]["date" ]} ' )
329349 combined_trials = load_combined_trials (session_paths , one )
330350 performance , contrasts , _ = training .compute_performance (combined_trials , prob_right = True )
351+ psychs = {}
352+ psychs ['50' ] = training .compute_psychometric (trials , block = 0.5 )
353+ psychs ['20' ] = training .compute_psychometric (trials , block = 0.2 )
354+ psychs ['80' ] = training .compute_psychometric (trials , block = 0.8 )
355+
331356 performance_easy = training .compute_performance_easy (combined_trials )
332357 reaction_time = training .compute_median_reaction_time (combined_trials )
333358 n_trials = training .compute_n_trials (combined_trials )
@@ -344,6 +369,12 @@ def get_training_info_for_session(session_paths, one):
344369 sess_dict ['combined_sess_duration' ] = sess_duration
345370 sess_dict ['combined_n_delay' ] = n_delay
346371
372+ for bias in [50 , 20 , 80 ]:
373+ sess_dict [f'combined_bias_{ bias } ' ] = psychs [f'{ bias } ' ][0 ]
374+ sess_dict [f'combined_thres_{ bias } ' ] = psychs [f'{ bias } ' ][1 ]
375+ sess_dict [f'combined_lapsehigh_{ bias } ' ] = psychs [f'{ bias } ' ][2 ]
376+ sess_dict [f'combined_lapselow_{ bias } ' ] = psychs [f'{ bias } ' ][3 ]
377+
347378 # Case where two sessions on same day with different number of contrasts! Oh boy
348379 if sess_dict ['combined_performance' ].size != sess_dict ['performance' ].size :
349380 sess_dict ['performance' ] = \
@@ -363,6 +394,12 @@ def get_training_info_for_session(session_paths, one):
363394 sess_dict ['combined_sess_duration' ] = sess_dict ['sess_duration' ]
364395 sess_dict ['combined_n_delay' ] = sess_dict ['n_delay' ]
365396
397+ for bias in [50 , 20 , 80 ]:
398+ sess_dict [f'combined_bias_{ bias } ' ] = sess_dict [f'bias_{ bias } ' ]
399+ sess_dict [f'combined_thres_{ bias } ' ] = sess_dict [f'thres_{ bias } ' ]
400+ sess_dict [f'combined_lapsehigh_{ bias } ' ] = sess_dict [f'lapsehigh_{ bias } ' ]
401+ sess_dict [f'combined_lapselow_{ bias } ' ] = sess_dict [f'lapselow_{ bias } ' ]
402+
366403 return sess_dicts
367404
368405
@@ -384,7 +421,7 @@ def check_up_to_date(subj_path, df):
384421 df_session = pd .concat ([df_session , pd .DataFrame ({'date' : date , 'session_path' : str (sess )}, index = [0 ])],
385422 ignore_index = True )
386423
387- if df is None :
424+ if df is None or 'combined_thres_50' not in df . columns :
388425 return df_session
389426 else :
390427 # recorded_session_paths = df['session_path'].values
@@ -399,14 +436,18 @@ def plot_trial_count_and_session_duration(df, subject):
399436
400437 y1 = {'column' : 'combined_n_trials' ,
401438 'title' : 'Trial counts' ,
402- 'lim' : None }
439+ 'lim' : None ,
440+ 'color' : 'k' ,
441+ 'join' : True }
403442
404443 y2 = {'column' : 'combined_sess_duration' ,
405444 'title' : 'Session duration (mins)' ,
406445 'lim' : None ,
407- 'log' : False }
446+ 'color' : 'r' ,
447+ 'log' : False ,
448+ 'join' : True }
408449
409- ax = plot_over_days (df , y1 , y2 , subject )
450+ ax = plot_over_days (df , subject , y1 , y2 )
410451
411452 return ax
412453
@@ -416,40 +457,152 @@ def plot_performance_easy_median_reaction_time(df, subject):
416457
417458 y1 = {'column' : 'combined_performance_easy' ,
418459 'title' : 'Performance on easy trials' ,
419- 'lim' : [0 , 1.05 ]}
460+ 'lim' : [0 , 1.05 ],
461+ 'color' : 'k' ,
462+ 'join' : True }
420463
421464 y2 = {'column' : 'combined_reaction_time' ,
422465 'title' : 'Median reaction time (s)' ,
423466 'lim' : [0.1 , np .nanmax ([10 , np .nanmax (df .combined_reaction_time .values )])],
424- 'log' : True }
425- ax = plot_over_days (df , y1 , y2 , subject )
467+ 'color' : 'r' ,
468+ 'log' : True ,
469+ 'join' : True }
470+ ax = plot_over_days (df , subject , y1 , y2 )
426471
427472 return ax
428473
429474
430- def plot_over_days (df , y1 , y2 , subject , ax = None ):
475+ def plot_fit_params (df , subject ):
476+ fig , axs = plt .subplots (2 , 2 , figsize = (12 , 6 ))
477+ axs = axs .ravel ()
478+
479+ df = df .drop_duplicates ('date' ).reset_index (drop = True )
480+
481+ cmap = sns .diverging_palette (20 , 220 , n = 3 , center = "dark" )
482+
483+ y50 = {'column' : 'combined_bias_50' ,
484+ 'title' : 'Bias' ,
485+ 'lim' : [- 100 , 100 ],
486+ 'color' : cmap [1 ],
487+ 'join' : False }
488+
489+ y80 = {'column' : 'combined_bias_80' ,
490+ 'title' : 'Bias' ,
491+ 'lim' : [- 100 , 100 ],
492+ 'color' : cmap [2 ],
493+ 'join' : False }
494+
495+ y20 = {'column' : 'combined_bias_20' ,
496+ 'title' : 'Bias' ,
497+ 'lim' : [- 100 , 100 ],
498+ 'color' : cmap [0 ],
499+ 'join' : False }
500+
501+ plot_over_days (df , subject , y50 , ax = axs [0 ], legend = False , title = False )
502+ plot_over_days (df , subject , y80 , ax = axs [0 ], legend = False , title = False )
503+ plot_over_days (df , subject , y20 , ax = axs [0 ], legend = False , title = False )
504+ axs [0 ].axhline (16 , linewidth = 2 , linestyle = '--' , color = 'k' )
505+ axs [0 ].axhline (- 16 , linewidth = 2 , linestyle = '--' , color = 'k' )
506+
507+ y50 ['column' ] = 'combined_thres_50'
508+ y50 ['title' ] = 'Threshold'
509+ y50 ['lim' ] = [0 , 100 ]
510+ y80 ['column' ] = 'combined_thres_20'
511+ y80 ['title' ] = 'Threshold'
512+ y20 ['lim' ] = [0 , 100 ]
513+ y20 ['column' ] = 'combined_thres_80'
514+ y20 ['title' ] = 'Threshold'
515+ y80 ['lim' ] = [0 , 100 ]
516+
517+ plot_over_days (df , subject , y50 , ax = axs [1 ], legend = False , title = False )
518+ plot_over_days (df , subject , y80 , ax = axs [1 ], legend = False , title = False )
519+ plot_over_days (df , subject , y20 , ax = axs [1 ], legend = False , title = False )
520+ axs [1 ].axhline (19 , linewidth = 2 , linestyle = '--' , color = 'k' )
521+
522+ y50 ['column' ] = 'combined_lapselow_50'
523+ y50 ['title' ] = 'Lapse Low'
524+ y50 ['lim' ] = [0 , 1 ]
525+ y80 ['column' ] = 'combined_lapselow_20'
526+ y80 ['title' ] = 'Lapse Low'
527+ y80 ['lim' ] = [0 , 1 ]
528+ y20 ['column' ] = 'combined_lapselow_80'
529+ y20 ['title' ] = 'Lapse Low'
530+ y20 ['lim' ] = [0 , 1 ]
531+
532+ plot_over_days (df , subject , y50 , ax = axs [2 ], legend = False , title = False )
533+ plot_over_days (df , subject , y80 , ax = axs [2 ], legend = False , title = False )
534+ plot_over_days (df , subject , y20 , ax = axs [2 ], legend = False , title = False )
535+ axs [2 ].axhline (0.2 , linewidth = 2 , linestyle = '--' , color = 'k' )
536+
537+ y50 ['column' ] = 'combined_lapsehigh_50'
538+ y50 ['title' ] = 'Lapse High'
539+ y50 ['lim' ] = [0 , 1 ]
540+ y80 ['column' ] = 'combined_lapsehigh_20'
541+ y80 ['title' ] = 'Lapse High'
542+ y80 ['lim' ] = [0 , 1 ]
543+ y20 ['column' ] = 'combined_lapsehigh_80'
544+ y20 ['title' ] = 'Lapse High'
545+ y20 ['lim' ] = [0 , 1 ]
546+
547+ plot_over_days (df , subject , y50 , ax = axs [3 ], legend = False , title = False , training_lines = True )
548+ plot_over_days (df , subject , y80 , ax = axs [3 ], legend = False , title = False , training_lines = False )
549+ plot_over_days (df , subject , y20 , ax = axs [3 ], legend = False , title = False , training_lines = False )
550+ axs [3 ].axhline (0.2 , linewidth = 2 , linestyle = '--' , color = 'k' )
551+
552+ fig .suptitle (f'{ subject } { df .iloc [- 1 ]["date" ]} : { df .iloc [- 1 ]["training_status" ]} ' )
553+ lines , labels = axs [3 ].get_legend_handles_labels ()
554+ fig .legend (lines , labels , loc = 'upper center' , bbox_to_anchor = (0.5 , 0.1 ), fancybox = True , shadow = True , ncol = 5 )
555+
556+ legend_elements = [Line2D ([0 ], [0 ], marker = 'o' , color = 'w' , label = 'p=0.5' , markerfacecolor = cmap [1 ], markersize = 8 ),
557+ Line2D ([0 ], [0 ], marker = 'o' , color = 'w' , label = 'p=0.2' , markerfacecolor = cmap [0 ], markersize = 8 ),
558+ Line2D ([0 ], [0 ], marker = 'o' , color = 'w' , label = 'p=0.8' , markerfacecolor = cmap [2 ], markersize = 8 )]
559+ legend2 = plt .legend (handles = legend_elements , loc = 'upper right' , bbox_to_anchor = (1.1 , - 0.2 ), fancybox = True , shadow = True )
560+ fig .add_artist (legend2 )
561+
562+ return axs
563+
564+
565+ def plot_psychometric_curve (df , subject , one ):
566+ df = df .drop_duplicates ('date' ).reset_index (drop = True )
567+ sess_path = Path (df .iloc [- 1 ]["session_path" ])
568+ trials = load_trials (sess_path , one )
569+
570+ fig , ax1 = plt .subplots (figsize = (8 , 6 ))
571+
572+ training .plot_psychometric (trials , ax = ax1 , title = f'{ subject } { df .iloc [- 1 ]["date" ]} : { df .iloc [- 1 ]["training_status" ]} ' )
573+
574+ return ax1
575+
576+
577+ def plot_over_days (df , subject , y1 , y2 = None , ax = None , legend = True , title = True , training_lines = True ):
431578
432579 if ax is None :
433580 fig , ax1 = plt .subplots (figsize = (12 , 6 ))
434581 else :
435582 ax1 = ax
436583
437- ax2 = ax1 .twinx ()
438-
439584 dates = [datetime .strptime (dat , '%Y-%m-%d' ) for dat in df ['date' ]]
440- ax1 .plot (dates , df [y1 ['column' ]], 'k' )
441- ax1 .scatter (dates , df [y1 ['column' ]], c = 'k' )
585+ if y1 ['join' ]:
586+ ax1 .plot (dates , df [y1 ['column' ]], color = y1 ['color' ])
587+ ax1 .scatter (dates , df [y1 ['column' ]], color = y1 ['color' ])
442588 ax1 .set_ylabel (y1 ['title' ])
443589 ax1 .set_ylim (y1 ['lim' ])
444590
445- ax2 .plot (dates , df [y2 ['column' ]], 'r' )
446- ax2 .scatter (dates , df [y2 ['column' ]], c = 'r' )
447- ax2 .set_ylabel (y2 ['title' ])
448- ax2 .yaxis .label .set_color ('r' )
449- ax2 .tick_params (axis = 'y' , colors = 'r' )
450- ax2 .set_ylim (y2 ['lim' ])
451- if y2 ['log' ]:
452- ax2 .set_yscale ('log' )
591+ if y2 is not None :
592+ ax2 = ax1 .twinx ()
593+ if y2 ['join' ]:
594+ ax2 .plot (dates , df [y2 ['column' ]], color = y2 ['color' ])
595+ ax2 .scatter (dates , df [y2 ['column' ]], color = y2 ['color' ])
596+ ax2 .set_ylabel (y2 ['title' ])
597+ ax2 .yaxis .label .set_color (y2 ['color' ])
598+ ax2 .tick_params (axis = 'y' , colors = y2 ['color' ])
599+ ax2 .set_ylim (y2 ['lim' ])
600+ if y2 ['log' ]:
601+ ax2 .set_yscale ('log' )
602+
603+ ax2 .spines ['right' ].set_visible (False )
604+ ax2 .spines ['top' ].set_visible (False )
605+ ax2 .spines ['left' ].set_visible (False )
453606
454607 month_format = mdates .DateFormatter ('%b %Y' )
455608 month_locator = mdates .MonthLocator ()
@@ -462,20 +615,20 @@ def plot_over_days(df, y1, y2, subject, ax=None):
462615 ax1 .spines ['left' ].set_visible (False )
463616 ax1 .spines ['right' ].set_visible (False )
464617 ax1 .spines ['top' ].set_visible (False )
465- ax2 .spines ['right' ].set_visible (False )
466- ax2 .spines ['top' ].set_visible (False )
467- ax2 .spines ['left' ].set_visible (False )
468618
469- ax1 = add_training_lines (df , ax1 )
619+ if training_lines :
620+ ax1 = add_training_lines (df , ax1 )
470621
471- ax1 .set_title (f'{ subject } { df .iloc [- 1 ]["date" ]} : { df .iloc [- 1 ]["training_status" ]} ' )
472- box = ax1 .get_position ()
473- ax1 .set_position ([box .x0 , box .y0 + box .height * 0.1 ,
474- box .width , box .height * 0.9 ])
622+ if title :
623+ ax1 .set_title (f'{ subject } { df .iloc [- 1 ]["date" ]} : { df .iloc [- 1 ]["training_status" ]} ' )
475624
476625 # Put a legend below current axis
477- ax1 .legend (loc = 'upper center' , bbox_to_anchor = (0.5 , - 0.1 ),
478- fancybox = True , shadow = True , ncol = 5 )
626+ box = ax1 .get_position ()
627+ ax1 .set_position ([box .x0 , box .y0 + box .height * 0.1 ,
628+ box .width , box .height * 0.9 ])
629+ if legend :
630+ ax1 .legend (loc = 'upper center' , bbox_to_anchor = (0.5 , - 0.1 ),
631+ fancybox = True , shadow = True , ncol = 5 )
479632
480633 return ax1
481634
@@ -554,6 +707,8 @@ def make_plots(session_path, one, df=None, save=False, upload=False):
554707 ax1 = plot_trial_count_and_session_duration (df , subject )
555708 ax2 = plot_performance_easy_median_reaction_time (df , subject )
556709 ax3 = plot_heatmap_performance_over_days (df , subject )
710+ ax4 = plot_fit_params (df , subject )
711+ ax5 = plot_psychometric_curve (df , subject , one )
557712
558713 outputs = []
559714 if save :
@@ -570,6 +725,14 @@ def make_plots(session_path, one, df=None, save=False, upload=False):
570725 outputs .append (save_name )
571726 ax3 .get_figure ().savefig (save_name , bbox_inches = 'tight' )
572727
728+ save_name = save_path .joinpath ('subj_psychometric_fit_params.png' )
729+ outputs .append (save_name )
730+ ax4 [0 ].get_figure ().savefig (save_name , bbox_inches = 'tight' )
731+
732+ save_name = save_path .joinpath ('subj_psychometric_curve.png' )
733+ outputs .append (save_name )
734+ ax5 .get_figure ().savefig (save_name , bbox_inches = 'tight' )
735+
573736 if upload :
574737 subj = one .alyx .rest ('subjects' , 'list' , nickname = subject )[0 ]
575738 snp = ReportSnapshot (session_path , subj ['id' ], content_type = 'subject' , one = one )
0 commit comments