Skip to content

Commit ce241aa

Browse files
committed
Merge remote-tracking branch 'origin/psycho_plots' into hotfix/2.13.4
2 parents 959edc3 + dedfb79 commit ce241aa

File tree

1 file changed

+190
-31
lines changed

1 file changed

+190
-31
lines changed

ibllib/pipes/training_status.py

Lines changed: 190 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from pathlib import Path
1717
import matplotlib.pyplot as plt
1818
import matplotlib.dates as mdates
19+
from matplotlib.lines import Line2D
1920
from datetime import datetime
21+
import seaborn as sns
2022

2123
one = ONE()
2224

@@ -309,12 +311,26 @@ def get_training_info_for_session(session_paths, one):
309311
continue
310312

311313
sess_dict['performance'], sess_dict['contrasts'], _ = training.compute_performance(trials, prob_right=True)
314+
if sess_dict['task_protocol'] == 'training':
315+
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
316+
training.compute_psychometric(trials)
317+
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
318+
(np.nan, np.nan, np.nan, np.nan)
319+
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
320+
(np.nan, np.nan, np.nan, np.nan)
321+
else:
322+
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
323+
training.compute_psychometric(trials, block=0.5)
324+
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
325+
training.compute_psychometric(trials, block=0.2)
326+
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
327+
training.compute_psychometric(trials, block=0.8)
328+
312329
sess_dict['performance_easy'] = training.compute_performance_easy(trials)
313330
sess_dict['reaction_time'] = training.compute_median_reaction_time(trials)
314331
sess_dict['n_trials'] = training.compute_n_trials(trials)
315332
sess_dict['sess_duration'], sess_dict['n_delay'], sess_dict['location'] = \
316333
compute_session_duration_delay_location(session_path)
317-
sess_dict['task_protocol'] = get_session_extractor_type(session_path)
318334
sess_dict['training_status'] = 'not_computed'
319335

320336
sess_dicts.append(sess_dict)
@@ -328,6 +344,11 @@ def get_training_info_for_session(session_paths, one):
328344
print(f'{len(sess_dicts)} sessions being combined for date {sess_dicts[0]["date"]}')
329345
combined_trials = load_combined_trials(session_paths, one)
330346
performance, contrasts, _ = training.compute_performance(combined_trials, prob_right=True)
347+
psychs = {}
348+
psychs['50'] = training.compute_psychometric(trials, block=0.5)
349+
psychs['20'] = training.compute_psychometric(trials, block=0.2)
350+
psychs['80'] = training.compute_psychometric(trials, block=0.8)
351+
331352
performance_easy = training.compute_performance_easy(combined_trials)
332353
reaction_time = training.compute_median_reaction_time(combined_trials)
333354
n_trials = training.compute_n_trials(combined_trials)
@@ -344,6 +365,12 @@ def get_training_info_for_session(session_paths, one):
344365
sess_dict['combined_sess_duration'] = sess_duration
345366
sess_dict['combined_n_delay'] = n_delay
346367

368+
for bias in [50, 20, 80]:
369+
sess_dict[f'combined_bias_{bias}'] = psychs[f'{bias}'][0]
370+
sess_dict[f'combined_thres_{bias}'] = psychs[f'{bias}'][1]
371+
sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][2]
372+
sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][3]
373+
347374
# Case where two sessions on same day with different number of contrasts! Oh boy
348375
if sess_dict['combined_performance'].size != sess_dict['performance'].size:
349376
sess_dict['performance'] = \
@@ -363,6 +390,12 @@ def get_training_info_for_session(session_paths, one):
363390
sess_dict['combined_sess_duration'] = sess_dict['sess_duration']
364391
sess_dict['combined_n_delay'] = sess_dict['n_delay']
365392

393+
for bias in [50, 20, 80]: # TODO check with someone if this is the way to do it
394+
sess_dict[f'combined_bias_{bias}'] = sess_dict[f'bias_{bias}']
395+
sess_dict[f'combined_thres_{bias}'] = sess_dict[f'thres_{bias}']
396+
sess_dict[f'combined_lapsehigh_{bias}'] = sess_dict[f'lapsehigh_{bias}']
397+
sess_dict[f'combined_lapselow_{bias}'] = sess_dict[f'lapselow_{bias}']
398+
366399
return sess_dicts
367400

368401

@@ -384,7 +417,7 @@ def check_up_to_date(subj_path, df):
384417
df_session = pd.concat([df_session, pd.DataFrame({'date': date, 'session_path': str(sess)}, index=[0])],
385418
ignore_index=True)
386419

387-
if df is None:
420+
if df is None or 'combined_thres_50' not in df.columns:
388421
return df_session
389422
else:
390423
# recorded_session_paths = df['session_path'].values
@@ -399,14 +432,18 @@ def plot_trial_count_and_session_duration(df, subject):
399432

400433
y1 = {'column': 'combined_n_trials',
401434
'title': 'Trial counts',
402-
'lim': None}
435+
'lim': None,
436+
'color': 'k',
437+
'join': True}
403438

404439
y2 = {'column': 'combined_sess_duration',
405440
'title': 'Session duration (mins)',
406441
'lim': None,
407-
'log': False}
442+
'color': 'r',
443+
'log': False,
444+
'join': True}
408445

409-
ax = plot_over_days(df, y1, y2, subject)
446+
ax = plot_over_days(df, subject, y1, y2)
410447

411448
return ax
412449

@@ -416,40 +453,152 @@ def plot_performance_easy_median_reaction_time(df, subject):
416453

417454
y1 = {'column': 'combined_performance_easy',
418455
'title': 'Performance on easy trials',
419-
'lim': [0, 1.05]}
456+
'lim': [0, 1.05],
457+
'color': 'k',
458+
'join': True}
420459

421460
y2 = {'column': 'combined_reaction_time',
422461
'title': 'Median reaction time (s)',
423462
'lim': [0.1, np.nanmax([10, np.nanmax(df.combined_reaction_time.values)])],
424-
'log': True}
425-
ax = plot_over_days(df, y1, y2, subject)
463+
'color': 'r',
464+
'log': True,
465+
'join': True}
466+
ax = plot_over_days(df, subject, y1, y2)
426467

427468
return ax
428469

429470

430-
def plot_over_days(df, y1, y2, subject, ax=None):
471+
def plot_fit_params(df, subject):
472+
fig, axs = plt.subplots(2, 2, figsize=(12, 6))
473+
axs = axs.ravel()
474+
475+
df = df.drop_duplicates('date').reset_index(drop=True)
476+
477+
cmap = sns.diverging_palette(20, 220, n=3, center="dark")
478+
479+
y50 = {'column': 'combined_bias_50',
480+
'title': 'Bias',
481+
'lim': [-100, 100],
482+
'color': cmap[1],
483+
'join': False}
484+
485+
y80 = {'column': 'combined_bias_80',
486+
'title': 'Bias',
487+
'lim': [-100, 100],
488+
'color': cmap[2],
489+
'join': False}
490+
491+
y20 = {'column': 'combined_bias_20',
492+
'title': 'Bias',
493+
'lim': [-100, 100],
494+
'color': cmap[0],
495+
'join': False}
496+
497+
plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False)
498+
plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False)
499+
plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False)
500+
axs[0].axhline(16, linewidth=2, linestyle='--', color='k')
501+
axs[0].axhline(-16, linewidth=2, linestyle='--', color='k')
502+
503+
y50['column'] = 'combined_thres_50'
504+
y50['title'] = 'Threshold'
505+
y50['lim'] = [0, 100]
506+
y80['column'] = 'combined_thres_20'
507+
y80['title'] = 'Threshold'
508+
y20['lim'] = [0, 100]
509+
y20['column'] = 'combined_thres_80'
510+
y20['title'] = 'Threshold'
511+
y80['lim'] = [0, 100]
512+
513+
plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False)
514+
plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False)
515+
plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False)
516+
axs[1].axhline(19, linewidth=2, linestyle='--', color='k')
517+
518+
y50['column'] = 'combined_lapselow_50'
519+
y50['title'] = 'Lapse Low'
520+
y50['lim'] = [0, 1]
521+
y80['column'] = 'combined_lapselow_20'
522+
y80['title'] = 'Lapse Low'
523+
y80['lim'] = [0, 1]
524+
y20['column'] = 'combined_lapselow_80'
525+
y20['title'] = 'Lapse Low'
526+
y20['lim'] = [0, 1]
527+
528+
plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False)
529+
plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False)
530+
plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False)
531+
axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k')
532+
533+
y50['column'] = 'combined_lapsehigh_50'
534+
y50['title'] = 'Lapse High'
535+
y50['lim'] = [0, 1]
536+
y80['column'] = 'combined_lapsehigh_20'
537+
y80['title'] = 'Lapse High'
538+
y80['lim'] = [0, 1]
539+
y20['column'] = 'combined_lapsehigh_80'
540+
y20['title'] = 'Lapse High'
541+
y20['lim'] = [0, 1]
542+
543+
plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True)
544+
plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False)
545+
plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False)
546+
axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k')
547+
548+
fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
549+
lines, labels = axs[3].get_legend_handles_labels()
550+
fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5)
551+
552+
legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8),
553+
Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8),
554+
Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)]
555+
legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True)
556+
fig.add_artist(legend2)
557+
558+
return axs
559+
560+
561+
def plot_psychometric_curve(df, subject, one):
562+
df = df.drop_duplicates('date').reset_index(drop=True)
563+
sess_path = Path(df.iloc[-1]["session_path"])
564+
trials = load_trials(sess_path, one)
565+
566+
fig, ax1 = plt.subplots(figsize=(8, 6))
567+
568+
training.plot_psychometric(trials, ax=ax1, title=f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
569+
570+
return ax1
571+
572+
573+
def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, training_lines=True):
431574

432575
if ax is None:
433576
fig, ax1 = plt.subplots(figsize=(12, 6))
434577
else:
435578
ax1 = ax
436579

437-
ax2 = ax1.twinx()
438-
439580
dates = [datetime.strptime(dat, '%Y-%m-%d') for dat in df['date']]
440-
ax1.plot(dates, df[y1['column']], 'k')
441-
ax1.scatter(dates, df[y1['column']], c='k')
581+
if y1['join']:
582+
ax1.plot(dates, df[y1['column']], color=y1['color'])
583+
ax1.scatter(dates, df[y1['column']], color=y1['color'])
442584
ax1.set_ylabel(y1['title'])
443585
ax1.set_ylim(y1['lim'])
444586

445-
ax2.plot(dates, df[y2['column']], 'r')
446-
ax2.scatter(dates, df[y2['column']], c='r')
447-
ax2.set_ylabel(y2['title'])
448-
ax2.yaxis.label.set_color('r')
449-
ax2.tick_params(axis='y', colors='r')
450-
ax2.set_ylim(y2['lim'])
451-
if y2['log']:
452-
ax2.set_yscale('log')
587+
if y2 is not None:
588+
ax2 = ax1.twinx()
589+
if y2['join']:
590+
ax2.plot(dates, df[y2['column']], color=y2['color'])
591+
ax2.scatter(dates, df[y2['column']], color=y2['color'])
592+
ax2.set_ylabel(y2['title'])
593+
ax2.yaxis.label.set_color(y2['color'])
594+
ax2.tick_params(axis='y', colors=y2['color'])
595+
ax2.set_ylim(y2['lim'])
596+
if y2['log']:
597+
ax2.set_yscale('log')
598+
599+
ax2.spines['right'].set_visible(False)
600+
ax2.spines['top'].set_visible(False)
601+
ax2.spines['left'].set_visible(False)
453602

454603
month_format = mdates.DateFormatter('%b %Y')
455604
month_locator = mdates.MonthLocator()
@@ -462,20 +611,20 @@ def plot_over_days(df, y1, y2, subject, ax=None):
462611
ax1.spines['left'].set_visible(False)
463612
ax1.spines['right'].set_visible(False)
464613
ax1.spines['top'].set_visible(False)
465-
ax2.spines['right'].set_visible(False)
466-
ax2.spines['top'].set_visible(False)
467-
ax2.spines['left'].set_visible(False)
468614

469-
ax1 = add_training_lines(df, ax1)
615+
if training_lines:
616+
ax1 = add_training_lines(df, ax1)
470617

471-
ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
472-
box = ax1.get_position()
473-
ax1.set_position([box.x0, box.y0 + box.height * 0.1,
474-
box.width, box.height * 0.9])
618+
if title:
619+
ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
475620

476621
# Put a legend below current axis
477-
ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
478-
fancybox=True, shadow=True, ncol=5)
622+
box = ax1.get_position()
623+
ax1.set_position([box.x0, box.y0 + box.height * 0.1,
624+
box.width, box.height * 0.9])
625+
if legend:
626+
ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
627+
fancybox=True, shadow=True, ncol=5)
479628

480629
return ax1
481630

@@ -554,6 +703,8 @@ def make_plots(session_path, one, df=None, save=False, upload=False):
554703
ax1 = plot_trial_count_and_session_duration(df, subject)
555704
ax2 = plot_performance_easy_median_reaction_time(df, subject)
556705
ax3 = plot_heatmap_performance_over_days(df, subject)
706+
ax4 = plot_fit_params(df, subject)
707+
ax5 = plot_psychometric_curve(df, subject, one)
557708

558709
outputs = []
559710
if save:
@@ -570,6 +721,14 @@ def make_plots(session_path, one, df=None, save=False, upload=False):
570721
outputs.append(save_name)
571722
ax3.get_figure().savefig(save_name, bbox_inches='tight')
572723

724+
save_name = save_path.joinpath('subj_psychometric_fit_params.png')
725+
outputs.append(save_name)
726+
ax4[0].get_figure().savefig(save_name, bbox_inches='tight')
727+
728+
save_name = save_path.joinpath('subj_psychometric_curve.png')
729+
outputs.append(save_name)
730+
ax5.get_figure().savefig(save_name, bbox_inches='tight')
731+
573732
if upload:
574733
subj = one.alyx.rest('subjects', 'list', nickname=subject)[0]
575734
snp = ReportSnapshot(session_path, subj['id'], content_type='subject', one=one)

0 commit comments

Comments
 (0)