Skip to content

Commit 5d595a7

Browse files
committed
brain_region plot and cdf amplitude plot
1 parent 95a0553 commit 5d595a7

File tree

3 files changed

+90
-3
lines changed

3 files changed

+90
-3
lines changed

brainbox/ephys_plots.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import numpy as np
22
from matplotlib import cm
3-
3+
import matplotlib.pyplot as plt
44
from brainbox.plot_base import (ImagePlot, ScatterPlot, ProbePlot, LinePlot, plot_line,
55
plot_image, plot_probe, plot_scatter, arrange_channels2banks)
66
from brainbox.processing import bincount2D, compute_cluster_average
7+
from ibllib.atlas.regions import BrainRegions
78

89

910
def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords, chn_inds, freq_range=(0, 300),
@@ -372,3 +373,78 @@ def line_amp_plot(spike_amps, spike_depths, spike_times, chn_coords, d_bin=10, d
372373
fig, ax = plot_line(data.convert2dict())
373374
return data.convert2dict(), fig, ax
374375
return data
376+
377+
378+
def plot_brain_regions(channel_ids, channel_depths=None, brain_regions=None, display=True, ax=None):
379+
380+
if channel_depths is not None:
381+
assert channel_ids.shape[0] == channel_depths.shape[0]
382+
383+
br = brain_regions or BrainRegions()
384+
385+
region_info = br.get(channel_ids)
386+
boundaries = np.where(np.diff(region_info.id) != 0)[0]
387+
boundaries = np.r_[0, boundaries, region_info.id.shape[0] - 1]
388+
389+
regions = np.c_[boundaries[0:-1], boundaries[1:]]
390+
if channel_depths is not None:
391+
regions = channel_depths[regions]
392+
region_labels = np.c_[np.mean(regions, axis=1), region_info.acronym[boundaries[1:]]]
393+
region_colours = region_info.rgb[boundaries[1:]]
394+
395+
if display:
396+
if ax is None:
397+
fig, ax = plt.subplots()
398+
399+
for reg, col in zip(regions, region_colours):
400+
height = np.abs(reg[1] - reg[0])
401+
color = col / 255
402+
ax.bar(x=0.5, height=height, width=1, color=color, bottom=reg[0], edgecolor='w')
403+
ax.set_yticks(region_labels[:, 0].astype(int))
404+
ax.yaxis.set_tick_params(labelsize=8)
405+
ax.get_xaxis().set_visible(False)
406+
ax.set_yticklabels(region_labels[:, 1])
407+
ax.spines['right'].set_visible(False)
408+
ax.spines['top'].set_visible(False)
409+
ax.spines['bottom'].set_visible(False)
410+
411+
return fig, ax
412+
else:
413+
return regions, region_labels, region_colours
414+
415+
416+
def plot_cdf(spike_amps, spike_depths, spike_times, n_amp_bins=10, d_bin=40, amp_range=None, d_range=None,
417+
display=False, cmap='hot'):
418+
419+
amp_range = amp_range or np.quantile(spike_amps, (0, 0.9))
420+
amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
421+
d_range = d_range or [0, 3840]
422+
depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin)
423+
t_bin = np.max(spike_times)
424+
425+
def histc(x, bins):
426+
map_to_bins = np.digitize(x, bins) # Get indices of the bins to which each value in input array belongs.
427+
res = np.zeros(bins.shape)
428+
429+
for el in map_to_bins:
430+
res[el - 1] += 1 # Increment appropriate bin.
431+
return res
432+
433+
cdfs = np.empty((len(depth_bins) - 1, n_amp_bins))
434+
for d in range(len(depth_bins) - 1):
435+
spikes = np.bitwise_and(spike_depths > depth_bins[d], spike_depths <= depth_bins[d + 1])
436+
h = histc(spike_amps[spikes], amp_bins) / t_bin
437+
hcsum = np.cumsum(h[::-1])
438+
cdfs[d, :] = hcsum[::-1]
439+
440+
cdfs[cdfs == 0] = np.nan
441+
442+
data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap)
443+
data.set_labels(title='Cumulative Amplitude', xlabel='Spike amplitude (uV)',
444+
ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')
445+
446+
if display:
447+
fig, ax = plot_image(data.convert2dict(), fig_kwargs={'figsize': [3, 7]})
448+
return data.convert2dict(), fig, ax
449+
450+
return data

brainbox/task/passive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def get_stim_aligned_activity(stim_events, spike_times, spike_depths, z_score_fl
206206
base_intervals = np.c_[stim_times - base_stim, stim_times - pre_stim]
207207
out_intervals = stim_intervals[:, 1] > times[-1]
208208

209-
idx_stim = np.searchsorted(times, stim_intervals)[np.invert(out_intervals)]
210-
idx_base = np.searchsorted(times, base_intervals)[np.invert(out_intervals)]
209+
idx_stim = np.searchsorted(times, stim_intervals, side='right')[np.invert(out_intervals)]
210+
idx_base = np.searchsorted(times, base_intervals, side='right')[np.invert(out_intervals)]
211211

212212
stim_trials = np.zeros((depths.shape[0], n_bins, idx_stim.shape[0]))
213213
noise_trials = np.zeros((depths.shape[0], n_bins_base, idx_stim.shape[0]))

ibllib/atlas/regions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,17 @@ def _mapping_from_regions_list(self, new_map, lateralize=False):
149149
mapind = mapind[iregion]
150150
return mapind
151151

152+
def remap(self, region_ids, source_map='Allen', target_map='Beryl'):
153+
"""
154+
Remap atlas regions ids from source map to target map
155+
:param region_ids:
156+
:param source_map:
157+
:param target_map:
158+
:return:
159+
"""
160+
_, inds = ismember(region_ids, self.id[self.mappings[source_map]])
161+
return self.id[self.mappings[target_map][inds]]
162+
152163

153164
def regions_from_allen_csv():
154165
"""

0 commit comments

Comments
 (0)