|
1 | 1 | import numpy as np |
2 | 2 | from matplotlib import cm |
3 | | - |
| 3 | +import matplotlib.pyplot as plt |
4 | 4 | from brainbox.plot_base import (ImagePlot, ScatterPlot, ProbePlot, LinePlot, plot_line, |
5 | 5 | plot_image, plot_probe, plot_scatter, arrange_channels2banks) |
6 | 6 | from brainbox.processing import bincount2D, compute_cluster_average |
| 7 | +from ibllib.atlas.regions import BrainRegions |
7 | 8 |
|
8 | 9 |
|
9 | 10 | def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords, chn_inds, freq_range=(0, 300), |
@@ -372,3 +373,100 @@ def line_amp_plot(spike_amps, spike_depths, spike_times, chn_coords, d_bin=10, d |
372 | 373 | fig, ax = plot_line(data.convert2dict()) |
373 | 374 | return data.convert2dict(), fig, ax |
374 | 375 | return data |
| 376 | + |
| 377 | + |
| 378 | +def plot_brain_regions(channel_ids, channel_depths=None, brain_regions=None, display=True, ax=None): |
| 379 | + """ |
| 380 | + Plot brain regions along probe, if channel depths is provided will plot along depth otherwise along channel idx |
| 381 | + :param channel_ids: atlas ids for each channel |
| 382 | + :param channel_depths: depth along probe for each channel |
| 383 | + :param brain_regions: BrainRegions object |
| 384 | + :param display: whether to output plot |
| 385 | + :param ax: axis to plot on |
| 386 | + :return: |
| 387 | + """ |
| 388 | + |
| 389 | + if channel_depths is not None: |
| 390 | + assert channel_ids.shape[0] == channel_depths.shape[0] |
| 391 | + |
| 392 | + br = brain_regions or BrainRegions() |
| 393 | + |
| 394 | + region_info = br.get(channel_ids) |
| 395 | + boundaries = np.where(np.diff(region_info.id) != 0)[0] |
| 396 | + boundaries = np.r_[0, boundaries, region_info.id.shape[0] - 1] |
| 397 | + |
| 398 | + regions = np.c_[boundaries[0:-1], boundaries[1:]] |
| 399 | + if channel_depths is not None: |
| 400 | + regions = channel_depths[regions] |
| 401 | + region_labels = np.c_[np.mean(regions, axis=1), region_info.acronym[boundaries[1:]]] |
| 402 | + region_colours = region_info.rgb[boundaries[1:]] |
| 403 | + |
| 404 | + if display: |
| 405 | + if ax is None: |
| 406 | + fig, ax = plt.subplots() |
| 407 | + |
| 408 | + for reg, col in zip(regions, region_colours): |
| 409 | + height = np.abs(reg[1] - reg[0]) |
| 410 | + color = col / 255 |
| 411 | + ax.bar(x=0.5, height=height, width=1, color=color, bottom=reg[0], edgecolor='w') |
| 412 | + ax.set_yticks(region_labels[:, 0].astype(int)) |
| 413 | + ax.yaxis.set_tick_params(labelsize=8) |
| 414 | + ax.get_xaxis().set_visible(False) |
| 415 | + ax.set_yticklabels(region_labels[:, 1]) |
| 416 | + ax.spines['right'].set_visible(False) |
| 417 | + ax.spines['top'].set_visible(False) |
| 418 | + ax.spines['bottom'].set_visible(False) |
| 419 | + |
| 420 | + return fig, ax |
| 421 | + else: |
| 422 | + return regions, region_labels, region_colours |
| 423 | + |
| 424 | + |
| 425 | +def plot_cdf(spike_amps, spike_depths, spike_times, n_amp_bins=10, d_bin=40, amp_range=None, d_range=None, |
| 426 | + display=False, cmap='hot'): |
| 427 | + """ |
| 428 | + Plot cumulative amplitude of spikes across depth |
| 429 | + :param spike_amps: |
| 430 | + :param spike_depths: |
| 431 | + :param spike_times: |
| 432 | + :param n_amp_bins: number of amplitude bins to use |
| 433 | + :param d_bin: the value of the depth bins in um (default is 40 um) |
| 434 | + :param amp_range: amp range to use [amp_min, amp_max], if not given automatically computed from spike_amps |
| 435 | + :param d_range: depth range to use, by default [0, 3840] |
| 436 | + :param display: whether or not to display plot |
| 437 | + :param cmap: |
| 438 | + :return: |
| 439 | + """ |
| 440 | + |
| 441 | + amp_range = amp_range or np.quantile(spike_amps, (0, 0.9)) |
| 442 | + amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins) |
| 443 | + d_range = d_range or [0, 3840] |
| 444 | + depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin) |
| 445 | + t_bin = np.max(spike_times) |
| 446 | + |
| 447 | + def histc(x, bins): |
| 448 | + map_to_bins = np.digitize(x, bins) # Get indices of the bins to which each value in input array belongs. |
| 449 | + res = np.zeros(bins.shape) |
| 450 | + |
| 451 | + for el in map_to_bins: |
| 452 | + res[el - 1] += 1 # Increment appropriate bin. |
| 453 | + return res |
| 454 | + |
| 455 | + cdfs = np.empty((len(depth_bins) - 1, n_amp_bins)) |
| 456 | + for d in range(len(depth_bins) - 1): |
| 457 | + spikes = np.bitwise_and(spike_depths > depth_bins[d], spike_depths <= depth_bins[d + 1]) |
| 458 | + h = histc(spike_amps[spikes], amp_bins) / t_bin |
| 459 | + hcsum = np.cumsum(h[::-1]) |
| 460 | + cdfs[d, :] = hcsum[::-1] |
| 461 | + |
| 462 | + cdfs[cdfs == 0] = np.nan |
| 463 | + |
| 464 | + data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap) |
| 465 | + data.set_labels(title='Cumulative Amplitude', xlabel='Spike amplitude (uV)', |
| 466 | + ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)') |
| 467 | + |
| 468 | + if display: |
| 469 | + fig, ax = plot_image(data.convert2dict(), fig_kwargs={'figsize': [3, 7]}) |
| 470 | + return data.convert2dict(), fig, ax |
| 471 | + |
| 472 | + return data |
0 commit comments