Commit 91c1c72

Merge branch 'hotfix/2.13.4' into develop
2 parents e8cce2d + 70cccde commit 91c1c72

File tree: 6 files changed, +210 −43 lines

6 files changed

+210
-43
lines changed

brainbox/io/spikeglx.py

Lines changed: 3 additions & 2 deletions

@@ -168,9 +168,10 @@ def read(self, nsel=slice(0, 10000), csel=slice(None), sync=True):
         """
         overload the read function by downloading the necessary chunks
         """
-        first_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.start + 0.01 * self.fs) - 1)
-        last_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.stop + 0.01 * self.fs) - 2)
+        first_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.start) - 1)
+        last_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.stop) - 1)
         n0 = self.chunks['chunk_bounds'][first_chunk]
+        _logger.debug(f'Streamer: caching sample {n0}, (t={n0 / self.fs})')
         self.cache_folder.mkdir(exist_ok=True, parents=True)
         sr = self._download_raw_partial(first_chunk=first_chunk, last_chunk=last_chunk)
         data = sr[nsel.start - n0: nsel.stop - n0, csel]
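For context on the change above (not part of the commit): the streamer maps the requested sample slice onto cached chunks with np.searchsorted over the chunk boundaries. A minimal sketch with made-up chunk bounds shows why dropping the 0.01 * fs padding and using -1 for both bounds still selects the chunk containing nsel.stop:

```python
import numpy as np

# hypothetical chunk boundaries in samples (one chunk every 30000 samples)
chunk_bounds = np.array([0, 30000, 60000, 90000, 120000])

def chunk_range(start, stop):
    # same pattern as the patched read(): "- 1" gives the chunk that contains
    # the sample, so first/last bracket the requested slice without padding
    first = np.maximum(0, np.searchsorted(chunk_bounds, start) - 1)
    last = np.maximum(0, np.searchsorted(chunk_bounds, stop) - 1)
    return int(first), int(last)

print(chunk_range(45000, 95000))  # (1, 3): chunks 1 through 3 cover samples 45000-95000
```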

ibllib/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-__version__ = "2.13.3"
+__version__ = "2.13.4"
 import warnings
 
 from ibllib.misc import logger_config

ibllib/pipes/training_status.py

Lines changed: 196 additions & 33 deletions

@@ -1,4 +1,3 @@
-from one.api import ONE
 import one.alf.io as alfio
 from one.alf.spec import is_session_path
 from one.alf.exceptions import ALFObjectNotFound
@@ -16,9 +15,10 @@
 from pathlib import Path
 import matplotlib.pyplot as plt
 import matplotlib.dates as mdates
+from matplotlib.lines import Line2D
 from datetime import datetime
+import seaborn as sns
 
-one = ONE()
 
 TRAINING_STATUS = {'not_computed': (-2, (0, 0, 0, 0)),
                    'habituation': (-1, (0, 0, 0, 0)),
@@ -301,6 +301,12 @@ def get_training_info_for_session(session_paths, one):
             sess_dict['n_delay'] = np.nan
             sess_dict['location'] = np.nan
             sess_dict['training_status'] = 'habituation'
+            sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
+                (np.nan, np.nan, np.nan, np.nan)
+            sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
+                (np.nan, np.nan, np.nan, np.nan)
+            sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
+                (np.nan, np.nan, np.nan, np.nan)
 
         else:
             # if we can't compute trials then we need to pass
@@ -309,12 +315,26 @@ def get_training_info_for_session(session_paths, one):
                 continue
 
             sess_dict['performance'], sess_dict['contrasts'], _ = training.compute_performance(trials, prob_right=True)
+            if sess_dict['task_protocol'] == 'training':
+                sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
+                    training.compute_psychometric(trials)
+                sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
+                    (np.nan, np.nan, np.nan, np.nan)
+                sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
+                    (np.nan, np.nan, np.nan, np.nan)
+            else:
+                sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
+                    training.compute_psychometric(trials, block=0.5)
+                sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
+                    training.compute_psychometric(trials, block=0.2)
+                sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
+                    training.compute_psychometric(trials, block=0.8)
+
             sess_dict['performance_easy'] = training.compute_performance_easy(trials)
             sess_dict['reaction_time'] = training.compute_median_reaction_time(trials)
             sess_dict['n_trials'] = training.compute_n_trials(trials)
             sess_dict['sess_duration'], sess_dict['n_delay'], sess_dict['location'] = \
                 compute_session_duration_delay_location(session_path)
-            sess_dict['task_protocol'] = get_session_extractor_type(session_path)
             sess_dict['training_status'] = 'not_computed'
 
         sess_dicts.append(sess_dict)
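A hedged sketch of the bookkeeping the branch above performs, assuming (as the unpacking in the diff implies) that training.compute_psychometric returns a (bias, threshold, lapse_high, lapse_low) tuple; fake_fit and the standalone dict are illustrative stand-ins, not the real fit:

```python
import numpy as np

def fake_fit(block):
    # stand-in for training.compute_psychometric(trials, block=block)
    return 0.0, 20.0, 0.1, 0.1

task_protocol = 'biased'  # or 'training'
sess_dict = {}
if task_protocol == 'training':
    # unbiased protocol: only the 50/50 block exists, the biased blocks stay NaN
    sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = fake_fit(0.5)
    for b in ('20', '80'):
        sess_dict[f'bias_{b}'] = sess_dict[f'thres_{b}'] = np.nan
        sess_dict[f'lapsehigh_{b}'] = sess_dict[f'lapselow_{b}'] = np.nan
else:
    # biased protocol: fit each probability block separately
    for block, key in [(0.5, '50'), (0.2, '20'), (0.8, '80')]:
        (sess_dict[f'bias_{key}'], sess_dict[f'thres_{key}'],
         sess_dict[f'lapsehigh_{key}'], sess_dict[f'lapselow_{key}']) = fake_fit(block)
```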
@@ -328,6 +348,11 @@ def get_training_info_for_session(session_paths, one):
         print(f'{len(sess_dicts)} sessions being combined for date {sess_dicts[0]["date"]}')
         combined_trials = load_combined_trials(session_paths, one)
         performance, contrasts, _ = training.compute_performance(combined_trials, prob_right=True)
+        psychs = {}
+        psychs['50'] = training.compute_psychometric(trials, block=0.5)
+        psychs['20'] = training.compute_psychometric(trials, block=0.2)
+        psychs['80'] = training.compute_psychometric(trials, block=0.8)
+
         performance_easy = training.compute_performance_easy(combined_trials)
         reaction_time = training.compute_median_reaction_time(combined_trials)
         n_trials = training.compute_n_trials(combined_trials)
@@ -344,6 +369,12 @@ def get_training_info_for_session(session_paths, one):
             sess_dict['combined_sess_duration'] = sess_duration
             sess_dict['combined_n_delay'] = n_delay
 
+            for bias in [50, 20, 80]:
+                sess_dict[f'combined_bias_{bias}'] = psychs[f'{bias}'][0]
+                sess_dict[f'combined_thres_{bias}'] = psychs[f'{bias}'][1]
+                sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][2]
+                sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][3]
+
             # Case where two sessions on same day with different number of contrasts! Oh boy
             if sess_dict['combined_performance'].size != sess_dict['performance'].size:
                 sess_dict['performance'] = \
@@ -363,6 +394,12 @@ def get_training_info_for_session(session_paths, one):
         sess_dict['combined_sess_duration'] = sess_dict['sess_duration']
         sess_dict['combined_n_delay'] = sess_dict['n_delay']
 
+        for bias in [50, 20, 80]:
+            sess_dict[f'combined_bias_{bias}'] = sess_dict[f'bias_{bias}']
+            sess_dict[f'combined_thres_{bias}'] = sess_dict[f'thres_{bias}']
+            sess_dict[f'combined_lapsehigh_{bias}'] = sess_dict[f'lapsehigh_{bias}']
+            sess_dict[f'combined_lapselow_{bias}'] = sess_dict[f'lapselow_{bias}']
+
     return sess_dicts

@@ -384,7 +421,7 @@ def check_up_to_date(subj_path, df):
         df_session = pd.concat([df_session, pd.DataFrame({'date': date, 'session_path': str(sess)}, index=[0])],
                                ignore_index=True)
 
-    if df is None:
+    if df is None or 'combined_thres_50' not in df.columns:
         return df_session
     else:
         # recorded_session_paths = df['session_path'].values
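The widened condition above means a cached dataframe written before this release, which lacks the psychometric columns, now triggers a full rebuild instead of an incremental update. A small sketch of that check; pandas is assumed and the column name is taken from the diff:

```python
import pandas as pd

def needs_full_rebuild(df):
    # mirrors the updated guard: rebuild when there is no cached table, or when
    # the cached table predates the combined psychometric columns
    return df is None or 'combined_thres_50' not in df.columns

old_table = pd.DataFrame({'date': ['2022-08-01'], 'combined_n_trials': [512]})
print(needs_full_rebuild(None))       # True
print(needs_full_rebuild(old_table))  # True: old cache lacks 'combined_thres_50'
```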
@@ -399,14 +436,18 @@ def plot_trial_count_and_session_duration(df, subject):
 
     y1 = {'column': 'combined_n_trials',
           'title': 'Trial counts',
-          'lim': None}
+          'lim': None,
+          'color': 'k',
+          'join': True}
 
     y2 = {'column': 'combined_sess_duration',
           'title': 'Session duration (mins)',
           'lim': None,
-          'log': False}
+          'color': 'r',
+          'log': False,
+          'join': True}
 
-    ax = plot_over_days(df, y1, y2, subject)
+    ax = plot_over_days(df, subject, y1, y2)
 
     return ax
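Usage sketch (values invented, subject nickname hypothetical): each y-axis is now described by a dict carrying its colour and whether the points are joined, and the subject is passed before the axis specs:

```python
import pandas as pd

# toy per-day table standing in for the real training dataframe
df = pd.DataFrame({'date': ['2022-08-01', '2022-08-02'],
                   'combined_n_trials': [450, 620],
                   'combined_sess_duration': [42, 55]})

y1 = {'column': 'combined_n_trials', 'title': 'Trial counts',
      'lim': None, 'color': 'k', 'join': True}
y2 = {'column': 'combined_sess_duration', 'title': 'Session duration (mins)',
      'lim': None, 'color': 'r', 'log': False, 'join': True}

# new argument order: subject nickname before the y-axis specs
# ax = plot_over_days(df, 'SWC_001', y1, y2)
```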

@@ -416,40 +457,152 @@ def plot_performance_easy_median_reaction_time(df, subject):
 
     y1 = {'column': 'combined_performance_easy',
           'title': 'Performance on easy trials',
-          'lim': [0, 1.05]}
+          'lim': [0, 1.05],
+          'color': 'k',
+          'join': True}
 
     y2 = {'column': 'combined_reaction_time',
           'title': 'Median reaction time (s)',
           'lim': [0.1, np.nanmax([10, np.nanmax(df.combined_reaction_time.values)])],
-          'log': True}
-    ax = plot_over_days(df, y1, y2, subject)
+          'color': 'r',
+          'log': True,
+          'join': True}
+    ax = plot_over_days(df, subject, y1, y2)
 
     return ax
 
 
-def plot_over_days(df, y1, y2, subject, ax=None):
+def plot_fit_params(df, subject):
+    fig, axs = plt.subplots(2, 2, figsize=(12, 6))
+    axs = axs.ravel()
+
+    df = df.drop_duplicates('date').reset_index(drop=True)
+
+    cmap = sns.diverging_palette(20, 220, n=3, center="dark")
+
+    y50 = {'column': 'combined_bias_50',
+           'title': 'Bias',
+           'lim': [-100, 100],
+           'color': cmap[1],
+           'join': False}
+
+    y80 = {'column': 'combined_bias_80',
+           'title': 'Bias',
+           'lim': [-100, 100],
+           'color': cmap[2],
+           'join': False}
+
+    y20 = {'column': 'combined_bias_20',
+           'title': 'Bias',
+           'lim': [-100, 100],
+           'color': cmap[0],
+           'join': False}
+
+    plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False)
+    plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False)
+    plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False)
+    axs[0].axhline(16, linewidth=2, linestyle='--', color='k')
+    axs[0].axhline(-16, linewidth=2, linestyle='--', color='k')
+
+    y50['column'] = 'combined_thres_50'
+    y50['title'] = 'Threshold'
+    y50['lim'] = [0, 100]
+    y80['column'] = 'combined_thres_20'
+    y80['title'] = 'Threshold'
+    y20['lim'] = [0, 100]
+    y20['column'] = 'combined_thres_80'
+    y20['title'] = 'Threshold'
+    y80['lim'] = [0, 100]
+
+    plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False)
+    plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False)
+    plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False)
+    axs[1].axhline(19, linewidth=2, linestyle='--', color='k')
+
+    y50['column'] = 'combined_lapselow_50'
+    y50['title'] = 'Lapse Low'
+    y50['lim'] = [0, 1]
+    y80['column'] = 'combined_lapselow_20'
+    y80['title'] = 'Lapse Low'
+    y80['lim'] = [0, 1]
+    y20['column'] = 'combined_lapselow_80'
+    y20['title'] = 'Lapse Low'
+    y20['lim'] = [0, 1]
+
+    plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False)
+    plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False)
+    plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False)
+    axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k')
+
+    y50['column'] = 'combined_lapsehigh_50'
+    y50['title'] = 'Lapse High'
+    y50['lim'] = [0, 1]
+    y80['column'] = 'combined_lapsehigh_20'
+    y80['title'] = 'Lapse High'
+    y80['lim'] = [0, 1]
+    y20['column'] = 'combined_lapsehigh_80'
+    y20['title'] = 'Lapse High'
+    y20['lim'] = [0, 1]
+
+    plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True)
+    plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False)
+    plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False)
+    axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k')
+
+    fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
+    lines, labels = axs[3].get_legend_handles_labels()
+    fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5)
+
+    legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8),
+                       Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8),
+                       Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)]
+    legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True)
+    fig.add_artist(legend2)
+
+    return axs
+
+
+def plot_psychometric_curve(df, subject, one):
+    df = df.drop_duplicates('date').reset_index(drop=True)
+    sess_path = Path(df.iloc[-1]["session_path"])
+    trials = load_trials(sess_path, one)
+
+    fig, ax1 = plt.subplots(figsize=(8, 6))
+
+    training.plot_psychometric(trials, ax=ax1, title=f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
+
+    return ax1
+
+
+def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, training_lines=True):
 
     if ax is None:
         fig, ax1 = plt.subplots(figsize=(12, 6))
     else:
         ax1 = ax
 
-    ax2 = ax1.twinx()
-
     dates = [datetime.strptime(dat, '%Y-%m-%d') for dat in df['date']]
-    ax1.plot(dates, df[y1['column']], 'k')
-    ax1.scatter(dates, df[y1['column']], c='k')
+    if y1['join']:
+        ax1.plot(dates, df[y1['column']], color=y1['color'])
+    ax1.scatter(dates, df[y1['column']], color=y1['color'])
     ax1.set_ylabel(y1['title'])
     ax1.set_ylim(y1['lim'])
 
-    ax2.plot(dates, df[y2['column']], 'r')
-    ax2.scatter(dates, df[y2['column']], c='r')
-    ax2.set_ylabel(y2['title'])
-    ax2.yaxis.label.set_color('r')
-    ax2.tick_params(axis='y', colors='r')
-    ax2.set_ylim(y2['lim'])
-    if y2['log']:
-        ax2.set_yscale('log')
+    if y2 is not None:
+        ax2 = ax1.twinx()
+        if y2['join']:
+            ax2.plot(dates, df[y2['column']], color=y2['color'])
+        ax2.scatter(dates, df[y2['column']], color=y2['color'])
+        ax2.set_ylabel(y2['title'])
+        ax2.yaxis.label.set_color(y2['color'])
+        ax2.tick_params(axis='y', colors=y2['color'])
+        ax2.set_ylim(y2['lim'])
+        if y2['log']:
+            ax2.set_yscale('log')
+
+        ax2.spines['right'].set_visible(False)
+        ax2.spines['top'].set_visible(False)
+        ax2.spines['left'].set_visible(False)
 
     month_format = mdates.DateFormatter('%b %Y')
     month_locator = mdates.MonthLocator()
@@ -462,20 +615,20 @@ def plot_over_days(df, y1, y2, subject, ax=None):
     ax1.spines['left'].set_visible(False)
     ax1.spines['right'].set_visible(False)
     ax1.spines['top'].set_visible(False)
-    ax2.spines['right'].set_visible(False)
-    ax2.spines['top'].set_visible(False)
-    ax2.spines['left'].set_visible(False)
 
-    ax1 = add_training_lines(df, ax1)
+    if training_lines:
+        ax1 = add_training_lines(df, ax1)
 
-    ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
-    box = ax1.get_position()
-    ax1.set_position([box.x0, box.y0 + box.height * 0.1,
-                      box.width, box.height * 0.9])
+    if title:
+        ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
 
     # Put a legend below current axis
-    ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
-               fancybox=True, shadow=True, ncol=5)
+    box = ax1.get_position()
+    ax1.set_position([box.x0, box.y0 + box.height * 0.1,
+                      box.width, box.height * 0.9])
+    if legend:
+        ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
+                   fancybox=True, shadow=True, ncol=5)
 
     return ax1
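The reworked plot_over_days makes the second axis optional and adds legend/title/training_lines switches, which is what lets plot_fit_params stack three single-series calls onto one panel. A sketch of that overlay pattern (data and calls illustrative, hence commented out; the palette call matches the diff):

```python
import matplotlib.pyplot as plt
import seaborn as sns

# three colours for the 0.5 / 0.2 / 0.8 probability blocks, as in plot_fit_params
cmap = sns.diverging_palette(20, 220, n=3, center="dark")
fig, ax = plt.subplots()

y50 = {'column': 'combined_bias_50', 'title': 'Bias', 'lim': [-100, 100],
       'color': cmap[1], 'join': False}
y20 = {**y50, 'column': 'combined_bias_20', 'color': cmap[0]}
y80 = {**y50, 'column': 'combined_bias_80', 'color': cmap[2]}

# each call scatters one block onto the shared axis; titles and legends are
# suppressed so the three series share a single clean panel
# for y in (y50, y20, y80):
#     plot_over_days(df, subject, y, ax=ax, legend=False, title=False)
```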

@@ -554,6 +707,8 @@ def make_plots(session_path, one, df=None, save=False, upload=False):
     ax1 = plot_trial_count_and_session_duration(df, subject)
     ax2 = plot_performance_easy_median_reaction_time(df, subject)
     ax3 = plot_heatmap_performance_over_days(df, subject)
+    ax4 = plot_fit_params(df, subject)
+    ax5 = plot_psychometric_curve(df, subject, one)
 
     outputs = []
     if save:
@@ -570,6 +725,14 @@ def make_plots(session_path, one, df=None, save=False, upload=False):
         outputs.append(save_name)
         ax3.get_figure().savefig(save_name, bbox_inches='tight')
 
+        save_name = save_path.joinpath('subj_psychometric_fit_params.png')
+        outputs.append(save_name)
+        ax4[0].get_figure().savefig(save_name, bbox_inches='tight')
+
+        save_name = save_path.joinpath('subj_psychometric_curve.png')
+        outputs.append(save_name)
+        ax5.get_figure().savefig(save_name, bbox_inches='tight')
+
     if upload:
         subj = one.alyx.rest('subjects', 'list', nickname=subject)[0]
         snp = ReportSnapshot(session_path, subj['id'], content_type='subject', one=one)
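A note on the save block above (sketch, not from the commit): plot_fit_params returns the ravelled 2x2 array of axes, so the code indexes ax4[0] before calling get_figure(), whereas plot_psychometric_curve returns a single axis:

```python
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2)
axs = axs.ravel()                   # shape (4,), like the value returned by plot_fit_params
assert axs[0].get_figure() is fig   # any one axis reaches the parent figure for saving

fig2, ax_single = plt.subplots()    # like the value returned by plot_psychometric_curve
assert ax_single.get_figure() is fig2
```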
