Skip to content

Commit 6996db4

Browse files
committed
Merge branch 'brain_region_plots' into multiparts
2 parents d5387d9 + ae39f71 commit 6996db4

File tree

4 files changed

+119
-3
lines changed

4 files changed

+119
-3
lines changed

brainbox/ephys_plots.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import numpy as np
22
from matplotlib import cm
3-
3+
import matplotlib.pyplot as plt
44
from brainbox.plot_base import (ImagePlot, ScatterPlot, ProbePlot, LinePlot, plot_line,
55
plot_image, plot_probe, plot_scatter, arrange_channels2banks)
66
from brainbox.processing import bincount2D, compute_cluster_average
7+
from ibllib.atlas.regions import BrainRegions
78

89

910
def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords, chn_inds, freq_range=(0, 300),
@@ -372,3 +373,100 @@ def line_amp_plot(spike_amps, spike_depths, spike_times, chn_coords, d_bin=10, d
372373
fig, ax = plot_line(data.convert2dict())
373374
return data.convert2dict(), fig, ax
374375
return data
376+
377+
378+
def plot_brain_regions(channel_ids, channel_depths=None, brain_regions=None, display=True, ax=None):
379+
"""
380+
Plot brain regions along probe, if channel depths is provided will plot along depth otherwise along channel idx
381+
:param channel_ids: atlas ids for each channel
382+
:param channel_depths: depth along probe for each channel
383+
:param brain_regions: BrainRegions object
384+
:param display: whether to output plot
385+
:param ax: axis to plot on
386+
:return:
387+
"""
388+
389+
if channel_depths is not None:
390+
assert channel_ids.shape[0] == channel_depths.shape[0]
391+
392+
br = brain_regions or BrainRegions()
393+
394+
region_info = br.get(channel_ids)
395+
boundaries = np.where(np.diff(region_info.id) != 0)[0]
396+
boundaries = np.r_[0, boundaries, region_info.id.shape[0] - 1]
397+
398+
regions = np.c_[boundaries[0:-1], boundaries[1:]]
399+
if channel_depths is not None:
400+
regions = channel_depths[regions]
401+
region_labels = np.c_[np.mean(regions, axis=1), region_info.acronym[boundaries[1:]]]
402+
region_colours = region_info.rgb[boundaries[1:]]
403+
404+
if display:
405+
if ax is None:
406+
fig, ax = plt.subplots()
407+
408+
for reg, col in zip(regions, region_colours):
409+
height = np.abs(reg[1] - reg[0])
410+
color = col / 255
411+
ax.bar(x=0.5, height=height, width=1, color=color, bottom=reg[0], edgecolor='w')
412+
ax.set_yticks(region_labels[:, 0].astype(int))
413+
ax.yaxis.set_tick_params(labelsize=8)
414+
ax.get_xaxis().set_visible(False)
415+
ax.set_yticklabels(region_labels[:, 1])
416+
ax.spines['right'].set_visible(False)
417+
ax.spines['top'].set_visible(False)
418+
ax.spines['bottom'].set_visible(False)
419+
420+
return fig, ax
421+
else:
422+
return regions, region_labels, region_colours
423+
424+
425+
def plot_cdf(spike_amps, spike_depths, spike_times, n_amp_bins=10, d_bin=40, amp_range=None, d_range=None,
426+
display=False, cmap='hot'):
427+
"""
428+
Plot cumulative amplitude of spikes across depth
429+
:param spike_amps:
430+
:param spike_depths:
431+
:param spike_times:
432+
:param n_amp_bins: number of amplitude bins to use
433+
:param d_bin: the value of the depth bins in um (default is 40 um)
434+
:param amp_range: amp range to use [amp_min, amp_max], if not given automatically computed from spike_amps
435+
:param d_range: depth range to use, by default [0, 3840]
436+
:param display: whether or not to display plot
437+
:param cmap:
438+
:return:
439+
"""
440+
441+
amp_range = amp_range or np.quantile(spike_amps, (0, 0.9))
442+
amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
443+
d_range = d_range or [0, 3840]
444+
depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin)
445+
t_bin = np.max(spike_times)
446+
447+
def histc(x, bins):
448+
map_to_bins = np.digitize(x, bins) # Get indices of the bins to which each value in input array belongs.
449+
res = np.zeros(bins.shape)
450+
451+
for el in map_to_bins:
452+
res[el - 1] += 1 # Increment appropriate bin.
453+
return res
454+
455+
cdfs = np.empty((len(depth_bins) - 1, n_amp_bins))
456+
for d in range(len(depth_bins) - 1):
457+
spikes = np.bitwise_and(spike_depths > depth_bins[d], spike_depths <= depth_bins[d + 1])
458+
h = histc(spike_amps[spikes], amp_bins) / t_bin
459+
hcsum = np.cumsum(h[::-1])
460+
cdfs[d, :] = hcsum[::-1]
461+
462+
cdfs[cdfs == 0] = np.nan
463+
464+
data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap)
465+
data.set_labels(title='Cumulative Amplitude', xlabel='Spike amplitude (uV)',
466+
ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')
467+
468+
if display:
469+
fig, ax = plot_image(data.convert2dict(), fig_kwargs={'figsize': [3, 7]})
470+
return data.convert2dict(), fig, ax
471+
472+
return data

brainbox/task/passive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def get_stim_aligned_activity(stim_events, spike_times, spike_depths, z_score_fl
206206
base_intervals = np.c_[stim_times - base_stim, stim_times - pre_stim]
207207
out_intervals = stim_intervals[:, 1] > times[-1]
208208

209-
idx_stim = np.searchsorted(times, stim_intervals)[np.invert(out_intervals)]
210-
idx_base = np.searchsorted(times, base_intervals)[np.invert(out_intervals)]
209+
idx_stim = np.searchsorted(times, stim_intervals, side='right')[np.invert(out_intervals)]
210+
idx_base = np.searchsorted(times, base_intervals, side='right')[np.invert(out_intervals)]
211211

212212
stim_trials = np.zeros((depths.shape[0], n_bins, idx_stim.shape[0]))
213213
noise_trials = np.zeros((depths.shape[0], n_bins_base, idx_stim.shape[0]))

ibllib/atlas/regions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,17 @@ def _mapping_from_regions_list(self, new_map, lateralize=False):
149149
mapind = mapind[iregion]
150150
return mapind
151151

152+
def remap(self, region_ids, source_map='Allen', target_map='Beryl'):
153+
"""
154+
Remap atlas regions ids from source map to target map
155+
:param region_ids: atlas ids to map
156+
:param source_map: map name which original region_ids are in
157+
:param target_map: map name onto which to map
158+
:return:
159+
"""
160+
_, inds = ismember(region_ids, self.id[self.mappings[source_map]])
161+
return self.id[self.mappings[target_map][inds]]
162+
152163

153164
def regions_from_allen_csv():
154165
"""

ibllib/tests/test_atlas.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def test_mappings_not_lateralized(self):
5757
inds_[0] = 0
5858
assert np.all(inds == inds_)
5959

60+
def test_remap(self):
61+
# Test mapping atlas ids from one map to another
62+
atlas_id = np.array([463, 685]) # CA3 and PO
63+
cosmos_atlas_id = self.brs.remap(atlas_id, source_map='Allen', target_map='Cosmos')
64+
expectd_cosmos_id = [1089, 549] # HPF and TH
65+
assert np.all(cosmos_atlas_id == expectd_cosmos_id)
66+
6067

6168
class TestAtlasSlicesConversion(unittest.TestCase):
6269

0 commit comments

Comments
 (0)