Commit d0e7389
Merge branch 'develop' into single_source_version
2 parents: 6b4ae85 + 36a8f05


48 files changed: +2600, -698 lines

MANIFEST.in

Lines changed: 1 addition & 0 deletions

@@ -1,5 +1,6 @@
 include ibllib/atlas/allen_structure_tree.csv
 include ibllib/atlas/beryl.npy
+include ibllib/atlas/cosmos.npy
 include ibllib/io/extractors/extractor_types.json
 include brainbox/tests/wheel_test.p
 recursive-include brainbox/tests/fixtures *
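
Note: the new include line only affects which files ship with the source distribution. A minimal sketch of how the packaged file could be located at runtime, assuming the package data is installed alongside the ibllib.atlas module (the loading call is illustrative, not necessarily how ibllib reads it):

from pathlib import Path
import numpy as np
import ibllib.atlas

# Resolve the packaged file relative to the installed ibllib.atlas module.
cosmos_path = Path(ibllib.atlas.__file__).parent / 'cosmos.npy'
# Illustrative load; by analogy with beryl.npy this is presumably the Cosmos parcellation mapping.
cosmos = np.load(cosmos_path)
print(cosmos_path.exists(), cosmos.shape)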

brainbox/behavior/training.py

Lines changed: 24 additions & 25 deletions

@@ -1,12 +1,11 @@
-import logging
 from one.api import ONE
 import datetime
 import re
 import numpy as np
 from iblutil.util import Bunch
 import brainbox.behavior.pyschofit as psy
-
-logger = logging.getLogger('ibllib')
+import logging
+_logger = logging.getLogger('ibllib')
 
 
 def get_lab_training_status(lab, date=None, details=True, one=None):
@@ -110,7 +109,7 @@ def get_sessions(subj, date=None, one=None):
 
     # If still 0 sessions then return with warning
     if len(sessions) == 0:
-        logger.warning(f"No training sessions detected for {subj}")
+        _logger.warning(f"No training sessions detected for {subj}")
         return [None] * 4
 
     trials = Bunch()
@@ -274,30 +273,30 @@ def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psyc
     """
 
     if perf_easy is None:
-        logger.info(f"\n{subj} : {status} \nSession dates=[{sess_dates[0]}, {sess_dates[1]}, "
-                    f"{sess_dates[2]}]")
+        print(f"\n{subj} : {status} \nSession dates=[{sess_dates[0]}, {sess_dates[1]}, "
+              f"{sess_dates[2]}]")
     elif psych_20 is None:
-        logger.info(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, "
-                    f"Perf easy={[np.around(pe,2) for pe in perf_easy]}, "
-                    f"N trials={[nt for nt in n_trials]} "
-                    f"\nPsych fit over last 3 sessions: "
-                    f"bias={np.around(psych[0],2)}, thres={np.around(psych[1],2)}, "
-                    f"lapse_low={np.around(psych[2],2)}, lapse_high={np.around(psych[3],2)} "
-                    f"\nMedian reaction time at 0 contrast over last 3 sessions = "
-                    f"{np.around(rt,2)}")
+        print(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, "
+              f"Perf easy={[np.around(pe,2) for pe in perf_easy]}, "
+              f"N trials={[nt for nt in n_trials]} "
+              f"\nPsych fit over last 3 sessions: "
+              f"bias={np.around(psych[0],2)}, thres={np.around(psych[1],2)}, "
+              f"lapse_low={np.around(psych[2],2)}, lapse_high={np.around(psych[3],2)} "
+              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
+              f"{np.around(rt,2)}")
 
     else:
-        logger.info(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, "
-                    f"Perf easy={[np.around(pe,2) for pe in perf_easy]}, "
-                    f"N trials={[nt for nt in n_trials]} "
-                    f"\nPsych fit over last 3 sessions (20): "
-                    f"bias={np.around(psych_20[0],2)}, thres={np.around(psych_20[1],2)}, "
-                    f"lapse_low={np.around(psych_20[2],2)}, lapse_high={np.around(psych_20[3],2)} "
-                    f"\nPsych fit over last 3 sessions (80): bias={np.around(psych_80[0],2)}, "
-                    f"thres={np.around(psych_80[1],2)}, lapse_low={np.around(psych_80[2],2)}, "
-                    f"lapse_high={np.around(psych_80[3],2)} "
-                    f"\nMedian reaction time at 0 contrast over last 3 sessions = "
-                    f"{np.around(rt, 2)}")
+        print(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, "
+              f"Perf easy={[np.around(pe,2) for pe in perf_easy]}, "
+              f"N trials={[nt for nt in n_trials]} "
+              f"\nPsych fit over last 3 sessions (20): "
+              f"bias={np.around(psych_20[0],2)}, thres={np.around(psych_20[1],2)}, "
+              f"lapse_low={np.around(psych_20[2],2)}, lapse_high={np.around(psych_20[3],2)} "
+              f"\nPsych fit over last 3 sessions (80): bias={np.around(psych_80[0],2)}, "
+              f"thres={np.around(psych_80[1],2)}, lapse_low={np.around(psych_80[2],2)}, "
+              f"lapse_high={np.around(psych_80[3],2)} "
+              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
+              f"{np.around(rt, 2)}")
 
 
 def concatenate_trials(trials):
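
For context, a minimal usage sketch of display_status after this change. The subject name, dates and numbers below are invented, and psych/rt are assumed to be keyword arguments of the truncated signature shown in the hunk header; only the function name and the values referenced in the diff are taken from the code above:

from brainbox.behavior.training import display_status

# With psych_20 left as None, the elif branch above runs and the summary is
# printed to stdout instead of being routed through the 'ibllib' logger.
display_status('example_subject', ['2021-06-01', '2021-06-02', '2021-06-03'], 'in training',
               perf_easy=[0.91, 0.88, 0.93],
               n_trials=[712, 650, 688],
               psych=[-5.0, 20.0, 0.1, 0.05],
               rt=0.35)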

brainbox/ephys_plots.py

Lines changed: 101 additions & 1 deletion

@@ -1,9 +1,10 @@
 import numpy as np
 from matplotlib import cm
-
+import matplotlib.pyplot as plt
 from brainbox.plot_base import (ImagePlot, ScatterPlot, ProbePlot, LinePlot, plot_line,
                                 plot_image, plot_probe, plot_scatter, arrange_channels2banks)
 from brainbox.processing import bincount2D, compute_cluster_average
+from ibllib.atlas.regions import BrainRegions
 
 
 def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords, chn_inds, freq_range=(0, 300),
@@ -372,3 +373,102 @@ def line_amp_plot(spike_amps, spike_depths, spike_times, chn_coords, d_bin=10, d
         fig, ax = plot_line(data.convert2dict())
         return data.convert2dict(), fig, ax
     return data
+
+
+def plot_brain_regions(channel_ids, channel_depths=None, brain_regions=None, display=True, ax=None):
+    """
+    Plot brain regions along probe, if channel depths is provided will plot along depth otherwise along channel idx
+    :param channel_ids: atlas ids for each channel
+    :param channel_depths: depth along probe for each channel
+    :param brain_regions: BrainRegions object
+    :param display: whether to output plot
+    :param ax: axis to plot on
+    :return:
+    """
+
+    if channel_depths is not None:
+        assert channel_ids.shape[0] == channel_depths.shape[0]
+
+    br = brain_regions or BrainRegions()
+
+    region_info = br.get(channel_ids)
+    boundaries = np.where(np.diff(region_info.id) != 0)[0]
+    boundaries = np.r_[0, boundaries, region_info.id.shape[0] - 1]
+
+    regions = np.c_[boundaries[0:-1], boundaries[1:]]
+    if channel_depths is not None:
+        regions = channel_depths[regions]
+    region_labels = np.c_[np.mean(regions, axis=1), region_info.acronym[boundaries[1:]]]
+    region_colours = region_info.rgb[boundaries[1:]]
+
+    if display:
+        if ax is None:
+            fig, ax = plt.subplots()
+        else:
+            fig = ax.get_figure()
+
+        for reg, col in zip(regions, region_colours):
+            height = np.abs(reg[1] - reg[0])
+            color = col / 255
+            ax.bar(x=0.5, height=height, width=1, color=color, bottom=reg[0], edgecolor='w')
+        ax.set_yticks(region_labels[:, 0].astype(int))
+        ax.yaxis.set_tick_params(labelsize=8)
+        ax.get_xaxis().set_visible(False)
+        ax.set_yticklabels(region_labels[:, 1])
+        ax.spines['right'].set_visible(False)
+        ax.spines['top'].set_visible(False)
+        ax.spines['bottom'].set_visible(False)
+
+        return fig, ax
+    else:
+        return regions, region_labels, region_colours
+
+
+def plot_cdf(spike_amps, spike_depths, spike_times, n_amp_bins=10, d_bin=40, amp_range=None, d_range=None,
+             display=False, cmap='hot'):
+    """
+    Plot cumulative amplitude of spikes across depth
+    :param spike_amps:
+    :param spike_depths:
+    :param spike_times:
+    :param n_amp_bins: number of amplitude bins to use
+    :param d_bin: the value of the depth bins in um (default is 40 um)
+    :param amp_range: amp range to use [amp_min, amp_max], if not given automatically computed from spike_amps
+    :param d_range: depth range to use, by default [0, 3840]
+    :param display: whether or not to display plot
+    :param cmap:
+    :return:
+    """
+
+    amp_range = amp_range or np.quantile(spike_amps, (0, 0.9))
+    amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
+    d_range = d_range or [0, 3840]
+    depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin)
+    t_bin = np.max(spike_times)
+
+    def histc(x, bins):
+        map_to_bins = np.digitize(x, bins)  # Get indices of the bins to which each value in input array belongs.
+        res = np.zeros(bins.shape)
+
+        for el in map_to_bins:
+            res[el - 1] += 1  # Increment appropriate bin.
+        return res
+
+    cdfs = np.empty((len(depth_bins) - 1, n_amp_bins))
+    for d in range(len(depth_bins) - 1):
+        spikes = np.bitwise_and(spike_depths > depth_bins[d], spike_depths <= depth_bins[d + 1])
+        h = histc(spike_amps[spikes], amp_bins) / t_bin
+        hcsum = np.cumsum(h[::-1])
+        cdfs[d, :] = hcsum[::-1]
+
+    cdfs[cdfs == 0] = np.nan
+
+    data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap)
+    data.set_labels(title='Cumulative Amplitude', xlabel='Spike amplitude (uV)',
                    ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')
+
+    if display:
+        fig, ax = plot_image(data.convert2dict(), fig_kwargs={'figsize': [3, 7]})
+        return data.convert2dict(), fig, ax
+
+    return data
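
A minimal usage sketch for the two new plotting helpers. All input arrays below are synthetic placeholders; only the function names and signatures come from the diff above, and plot_brain_regions additionally needs the Allen atlas tables that BrainRegions loads:

import numpy as np
from brainbox.ephys_plots import plot_brain_regions, plot_cdf

rng = np.random.default_rng(0)

# 384 channels split across three example Allen structure ids, spaced 10 um apart in depth.
channel_ids = np.repeat(np.array([512, 313, 997]), 128)
channel_depths = np.arange(384) * 10.

# Fake spikes: amplitudes in volts, depths in um, times in seconds.
spike_amps = rng.gamma(2, 5e-5, 10000)
spike_depths = rng.uniform(0, 3840, 10000)
spike_times = np.sort(rng.uniform(0, 600, 10000))

fig, ax = plot_brain_regions(channel_ids, channel_depths=channel_depths, display=True)
data_dict, fig, ax = plot_cdf(spike_amps, spike_depths, spike_times, display=True)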
