diff --git a/nwbwidgets/analysis/placefields.py b/nwbwidgets/analysis/placefields.py new file mode 100644 index 00000000..30bb4c51 --- /dev/null +++ b/nwbwidgets/analysis/placefields.py @@ -0,0 +1,300 @@ +from functools import lru_cache + +import numpy as np +from scipy.ndimage.filters import gaussian_filter + + +def find_nearest(arr, tt): + """Used for picking out elements of a TimeSeries based on spike times + + Parameters + ---------- + arr: array-like + tt: array-like + + Returns + ------- + indices: array-like + + """ + arr = arr[arr > tt[0]] + arr = arr[arr < tt[-1]] + return np.searchsorted(tt, arr) + + +def smooth(y, box_pts): + """Moving average + + Parameters + ---------- + y: array-like + box_pts: int + + Returns + ------- + output: np.array(dtype=float) + + """ + box = np.ones(box_pts) / box_pts + return np.convolve(y, box, mode='same') + + +def compute_speed(pos, pos_tt, smooth_param=40): + """Compute boolean of whether the speed of the animal was above a threshold + for each time point + + Parameters + ---------- + pos: np.ndarray(dtype=float) + in meters + pos_tt: np.ndarray(dtype=float) + in seconds + smooth_param: float, optional + + Returns + ------- + running: np.ndarray(dtype=bool) + + """ + if len(pos.shape) > 1: + speed = np.hstack((0, np.sqrt(np.sum(np.diff(pos.T) ** 2, axis=0)) / np.diff(pos_tt))) + else: + speed = np.hstack((0, np.sqrt(np.diff(pos.T) ** 2) / np.diff(pos_tt))) + return smooth(speed, smooth_param) + + +def compute_2d_occupancy(pos, pos_tt, edges_x, edges_y, pixel_width, speed_thresh=0.03, velocity=None): + """Computes occupancy per bin in seconds + + Parameters + ---------- + pos: np.ndarray(dtype=float) + in meters + pos_tt: np.ndarray(dtype=float) + in seconds + edges_x: array-like + edges of histogram in meters + edges_y: array-like + edges of histogram in meters + pixel_width: array-like + speed_thresh: float, optional + in meters. Default = 3.0 cm/s + velocity: np.ndarray(dtype=float) + pre-computed velocity + + Returns + ------- + occupancy: np.ndarray(dtype=float) + in seconds + running: np.ndarray(dtype=bool) + + """ + + sampling_period = (np.max(pos_tt) - np.min(pos_tt)) / len(pos_tt) + np.seterr(invalid='ignore') + if velocity is None: + if pixel_width[1] is not int(1): + is_running = compute_speed(pos, pos_tt) > speed_thresh + else: + is_running = compute_speed(pos[:, 0], pos_tt) > speed_thresh + else: + is_running = np.linalg.norm(velocity) > speed_thresh + + run_pos = pos[is_running, :] + occupancy = np.histogram2d(run_pos[:, 1], + run_pos[:, 0], + [edges_y, edges_x])[0] * sampling_period # in seconds + + return occupancy, is_running + + +def compute_2d_n_spikes(pos, pos_tt, spikes, edges_x, edges_y, pixel_width, speed_thresh=0.03, velocity=None): + """Returns speed-gated position during spikes + + Parameters + ---------- + pos: np.ndarray(dtype=float) + (time x 2) in meters + pos_tt: np.ndarray(dtype=float) + (time,) in seconds + spikes: np.ndarray(dtype=float) + in seconds + edges_x: np.ndarray(dtype=float) + edges of histogram in meters + edges_y: np.ndarray(dtype=float) + edges of histogram in meters + pixel_width: array + speed_thresh: float + in meters. Default = 3.0 cm/s + velocity: np.ndarray(dtype=float) + pre-computed velocity + + Returns + ------- + """ + np.seterr(invalid='ignore') + if velocity is None: + if pixel_width[1] is not int(1): + is_running = compute_speed(pos, pos_tt) > speed_thresh + else: + is_running = compute_speed(pos[:, 0], pos_tt) > speed_thresh + else: + is_running = np.linalg.norm(velocity) > speed_thresh + + spike_pos_inds = find_nearest(spikes, pos_tt) + spike_pos_inds = spike_pos_inds[is_running[spike_pos_inds]] + pos_on_spikes = pos[spike_pos_inds, :] + + n_spikes = np.histogram2d(pos_on_spikes[:, 1], + pos_on_spikes[:, 0], + [edges_y, edges_x])[0] + + return n_spikes + + +def compute_2d_firing_rate(pos, pos_tt, spikes, + pixel_width, + speed_thresh=0.03, + gaussian_sd_x=0.0184, + gaussian_sd_y=0.0184, + x_start=None, x_stop=None, + y_start=None, y_stop=None, + velocity=None): + """Returns speed-gated occupancy and speed-gated and + Gaussian-filtered firing rate + + Parameters + ---------- + pos: np.ndarray(dtype=float) + (time x 2), in meters + pos_tt: np.ndarray(dtype=float) + (time,) in seconds + spikes: np.ndarray(dtype=float) + in seconds + pixel_width: array-like + speed_thresh: float, optional + in meters. Default = 3.0 cm/s + gaussian_sd_x: float, optional + width of gaussian kernel in x-dim, in meters. Default = 1.84 cm + gaussian_sd_y: float, optional + width of gaussian kernel in y-dim, in meters. Default = 1.84 cm + x_start: float, optional + x_stop: float, optional + y_start: float, optional + y_stop: float, optional + velocity: np.ndarray(dtype=float) + pre-computed velocity + + Returns + ------- + + occupancy: np.ndarray + in seconds + filtered_firing_rate: np.ndarray(shape=(cell, x, y), dtype=float) + in Hz + + """ + + x_start = np.nanmin(pos[:, 0]) if x_start is None else x_start + x_stop = np.nanmax(pos[:, 0]) if x_stop is None else x_stop + + y_start = np.nanmin(pos[:, 1]) if y_start is None else y_start + y_stop = np.nanmax(pos[:, 1]) if y_stop is None else y_stop + + edges_x = np.arange(x_start, x_stop, pixel_width[0]) + edges_y = np.arange(y_start, y_stop, pixel_width[1]) + + occupancy, running = compute_2d_occupancy(pos, pos_tt, edges_x, edges_y, pixel_width, speed_thresh, velocity) + + n_spikes = compute_2d_n_spikes(pos, pos_tt, spikes, edges_x, edges_y, pixel_width, speed_thresh, velocity) + + np.seterr(divide='ignore') + firing_rate = n_spikes / occupancy # in Hz + firing_rate[np.isnan(firing_rate)] = 0 # get rid of NaNs so convolution works + sigmas = [gaussian_sd_y / pixel_width[1], gaussian_sd_x / pixel_width[0]] + filtered_firing_rate = gaussian_filter(firing_rate, sigmas) + + # filter occupancy to create a mask so non-explored regions are nan'ed + sigmas_occ = [gaussian_sd_y / pixel_width[1] / 8, gaussian_sd_x / pixel_width[0] / 8] + filtered_occupancy = gaussian_filter(occupancy, sigmas_occ) + filtered_firing_rate[filtered_occupancy.astype('bool') < .00001] = np.nan + + return occupancy, filtered_firing_rate, [edges_x, edges_y] + + +def compute_1d_occupancy(pos, pos_tt, spatial_bins, sampling_rate, speed_thresh=0.03, velocity=None): + + np.seterr(invalid='ignore') + if velocity is None: + is_running = compute_speed(pos, pos_tt) > speed_thresh + else: + is_running = np.linalg.norm(velocity) > speed_thresh + + run_pos = pos[is_running, :] + finite_lin_pos = run_pos[np.isfinite(run_pos)] + + occupancy = np.histogram( + finite_lin_pos, bins=spatial_bins)[0][:-2] / sampling_rate + + return occupancy + + +def compute_linear_firing_rate(pos, pos_tt, spikes, gaussian_sd=0.0557, + spatial_bin_len=0.0168, speed_thresh=0.03, velocity=None): + """The occupancy and number of spikes, speed-gated, binned, and smoothed + over position + + Parameters + ---------- + pos: np.ndarray + linearized position + pos_tt: np.ndarray + sample times in seconds + spikes: np.ndarray + for a single cell in seconds + gaussian_sd: float (optional) + in meters. Default = 5.57 cm + spatial_bin_len: float (optional) + in meters. Default = 1.68 cm + speed_thresh: float (optional) + in m/s. Default = 0.03 + velocity: np.ndarray(dtype=float) + pre-computed velocity + + Returns + ------- + xx: np.ndarray + center of position bins in meters + occupancy: np.ndarray + time in each spatial bin in seconds, during appropriate trials and + while running + filtered_n_spikes: np.ndarray + number of spikes in each spatial bin, during appropriate trials, while + running, and processed with a Gaussian filter + + """ + spatial_bins = np.arange(np.nanmin(pos), np.nanmax(pos) + spatial_bin_len, spatial_bin_len) + + sampling_rate = len(pos_tt) / (np.nanmax(pos_tt) - np.nanmin(pos_tt)) + + occupancy = compute_1d_occupancy(pos, pos_tt, spatial_bins, sampling_rate, speed_thresh, velocity) + + np.seterr(invalid='ignore') + is_running = compute_speed(pos, pos_tt) > speed_thresh + + # find pos_tt bin associated with each spike + spike_pos_inds = find_nearest(spikes, pos_tt) + spike_pos_inds = spike_pos_inds[is_running[spike_pos_inds]] + pos_on_spikes = pos[spike_pos_inds] + finite_pos_on_spikes = pos_on_spikes[np.isfinite(pos_on_spikes)] + + n_spikes = np.histogram(finite_pos_on_spikes, bins=spatial_bins)[0][:-2] + + np.seterr(divide='ignore') + firing_rate = np.nan_to_num(n_spikes / occupancy) + + filtered_firing_rate = gaussian_filter( + firing_rate, gaussian_sd / spatial_bin_len) + xx = spatial_bins[:-3] + (spatial_bins[1] - spatial_bins[0]) / 2 + + return xx, occupancy, filtered_firing_rate diff --git a/nwbwidgets/ecephys.py b/nwbwidgets/ecephys.py index 6b5accd5..d036e881 100644 --- a/nwbwidgets/ecephys.py +++ b/nwbwidgets/ecephys.py @@ -1,11 +1,12 @@ import matplotlib.pyplot as plt import numpy as np import plotly.graph_objects as go -from plotly.colors import DEFAULT_PLOTLY_COLORS +import pynwb from ipywidgets import widgets, ValueWidget + from pynwb.ecephys import SpikeEventSeries, ElectricalSeries + from scipy.signal import stft -import pynwb from .base import fig2widget, lazy_tabs, render_dataframe from .timeseries import BaseGroupedTraceWidget diff --git a/nwbwidgets/placefield.py b/nwbwidgets/placefield.py new file mode 100644 index 00000000..a7c973d3 --- /dev/null +++ b/nwbwidgets/placefield.py @@ -0,0 +1,301 @@ +from functools import lru_cache + +import matplotlib.pyplot as plt +import numpy as np +import pynwb +from ipywidgets import widgets, BoundedFloatText, Dropdown, Checkbox, Layout + +from .analysis.placefields import compute_2d_firing_rate, compute_linear_firing_rate + +from .base import vis2widget +from .utils.widgets import interactive_output +from .utils.units import get_spike_times +from .utils.timeseries import get_timeseries_in_units, get_timeseries_tt + + +def route_placefield(spatial_series: pynwb.behavior.SpatialSeries): + if spatial_series.data.shape[1] == 2: + return PlaceFieldWidget(spatial_series) + elif spatial_series.data.shape[1] == 1: + return PlaceField1DWidget(spatial_series) + else: + print('Spatial series exceeds dimensionality for visualization') + return + + +class PlaceFieldWidget(widgets.HBox): + + def __init__(self, spatial_series: pynwb.behavior.SpatialSeries, + velocity: pynwb.TimeSeries = None, units = None, + **kwargs): + super().__init__() + if units is None: + self.units = spatial_series.get_ancestor('NWBFile').units + else: + self.units = units + self.pos_tt = get_timeseries_tt(spatial_series) + if velocity is not None: + self.velocity = velocity + self.disable = False + else: + self.velocity = None + self.disable = True + + self.get_position(spatial_series) + + bft_gaussian_x, bft_gaussian_y, bft_bin_num, bft_speed, dd_unit_select, cb_velocity = self.get_controls() + + self.controls = dict( + gaussian_sd_x=bft_gaussian_x, + gaussian_sd_y=bft_gaussian_y, + bin_num=bft_bin_num, + speed_thresh=bft_speed, + index=dd_unit_select, + use_velocity=cb_velocity + ) + + out_fig = interactive_output(self.do_rate_map, self.controls) + + self.children = [ + widgets.VBox([ + bft_gaussian_x, + bft_gaussian_y, + bft_bin_num, + bft_speed, + dd_unit_select, + cb_velocity, + ]), + vis2widget(out_fig) + ] + + def get_pixel_width(self, bin_num): + self.pixel_width = [(np.nanmax(self.pos) - np.nanmin(self.pos)) / bin_num] * 2 + + def get_position(self, spatial_series): + self.pos, self.unit = get_timeseries_in_units(spatial_series) + + def get_controls(self): + style = {'description_width': 'initial'} + bft_gaussian_x = BoundedFloatText(value=0.0184, min=0, max=99999, description='gaussian sd x (cm)', style=style) + bft_gaussian_y = BoundedFloatText(value=0.0184, min=0, max=99999, description='gaussian sd y (cm)', style=style) + bft_bin_num = BoundedFloatText(value=1000, min=0, max=99999, description='number of bins', style=style) + bft_speed = BoundedFloatText(value=0.03, min=0, max=99999, description='speed threshold (m/s)', style=style) + dd_unit_select = Dropdown(options=np.arange(len(self.units)), description='unit') + cb_velocity = Checkbox(value=False, description='use velocity', indent=False, disabled= self.disable) + + return bft_gaussian_x, bft_gaussian_y, bft_bin_num, bft_speed, dd_unit_select, cb_velocity + + def do_rate_map(self, index=0, speed_thresh=0.03, gaussian_sd_x=0.0184, gaussian_sd_y=0.0184, bin_num=1000, + use_velocity=False): + self.get_pixel_width(bin_num) + occupancy, filtered_firing_rate, [edges_x, edges_y] = self.compute_twodim_firing_rate(self.pixel_width[0], + index=index, + speed_thresh=speed_thresh, + gaussian_sd_x=gaussian_sd_x, + gaussian_sd_y=gaussian_sd_y, + use_velocity=use_velocity) + fig, ax = plt.subplots() + + im = ax.imshow(filtered_firing_rate, + extent=[edges_x[0], edges_x[-1], edges_y[0], edges_y[-1]], + aspect='equal') + ax.set_xlabel('x ({})'.format(self.unit)) + ax.set_ylabel('y ({})'.format(self.unit)) + + cbar = plt.colorbar(im) + cbar.ax.set_ylabel('firing rate (Hz)') + + return fig + + @lru_cache() + def compute_twodim_firing_rate(self, pixel_width, index=0, speed_thresh=0.03, gaussian_sd_x=0.0184, gaussian_sd_y=0.0184, + use_velocity=False): + tmin = min(self.pos_tt) + tmax = max(self.pos_tt) + spikes = get_spike_times(self.units, index, [tmin, tmax]) + if use_velocity == False: + occupancy, filtered_firing_rate, [edges_x, edges_y] = compute_2d_firing_rate(self.pos, self.pos_tt, spikes, + self.pixel_width, + speed_thresh=speed_thresh, + gaussian_sd_x=gaussian_sd_x, + gaussian_sd_y=gaussian_sd_y) + else: + occupancy, filtered_firing_rate, [edges_x, edges_y] = compute_2d_firing_rate(self.pos, self.pos_tt, spikes, + self.pixel_width, + speed_thresh=speed_thresh, + gaussian_sd_x=gaussian_sd_x, + gaussian_sd_y=gaussian_sd_y, + velocity=self.velocity) + return occupancy, filtered_firing_rate, [edges_x, edges_y] + + +class PlaceField1DWidget(widgets.HBox): + def __init__(self, spatial_series: pynwb.behavior.SpatialSeries, + velocity: pynwb.TimeSeries = None, + **kwargs): + + super().__init__() + + self.units = spatial_series.get_ancestor('NWBFile').units + index = np.arange(1, len(self.units)) + + self.pos_tt = get_timeseries_tt(spatial_series) + if velocity is not None: + self.velocity = velocity + else: + self.velocity = None + + self.pos, self.unit = get_timeseries_in_units(spatial_series) + + self.pixel_width = (np.nanmax(self.pos) - np.nanmin(self.pos)) / 1000 + + style = {'description_width': 'initial'} + bft_gaussian = BoundedFloatText(value=0.0557, min=0, max=99999, description='gaussian sd (m)', style=style) + bft_spatial_bin_len = BoundedFloatText(value=0.0168, min=0, max=99999, description='spatial bin length (m)', + style=style) + cb_normalize_select = Checkbox(value=False, description='normalize', indent=False) + cb_collapsed_select = Checkbox(value=False, description='collapsed', indent=False) + sm_unit_select = widgets.SelectMultiple(options=index, + value=[1, 2, 3, 4, 5], rows=20, + description='Select units', disabled=False + ) + + self.controls = dict( + gaussian_sd=bft_gaussian, + spatial_bin_len=bft_spatial_bin_len, + normalize=cb_normalize_select, + collapsed=cb_collapsed_select, + order=sm_unit_select + ) + + out_fig = interactive_output(self.do_1d_rate_map, self.controls) + checkboxes = widgets.HBox([cb_normalize_select, cb_collapsed_select]) + widget_fig = vis2widget(out_fig) + self.children = [widgets.HBox([ + widgets.VBox([ + bft_gaussian, + bft_spatial_bin_len, + checkboxes, + sm_unit_select + ], + layout=Layout(max_width="40%")), + widget_fig], + layout=Layout(width="100%", height="100%")) + ] + + def do_1d_rate_map(self, order=None, normalize=False, collapsed=False, gaussian_sd=0.0557, + spatial_bin_len=0.0168, **kwargs): + tmin = min(self.pos_tt) + tmax = max(self.pos_tt) + index = np.asarray(order) + + for i, ind in enumerate(index): + + all_unit_firing_rate_temp, xx = self.compute_1d_firing_rate(ind, tmin, tmax, gaussian_sd, spatial_bin_len) + if not i: + all_unit_firing_rate = np.zeros([len(index), len(xx)]) + + all_unit_firing_rate[i] = all_unit_firing_rate_temp + + fig, ax = plt.subplots(figsize=(7, 7)) + plot_tuning_curves1D(all_unit_firing_rate, xx, ax=ax, unit_labels=index, normalize=normalize, + collapsed=collapsed) + + return fig + + @lru_cache() + def compute_1d_firing_rate(self, ind, tmin, tmax, gaussian_sd, spatial_bin_len): + spikes = get_spike_times(self.units, ind, [tmin, tmax]) + xx, _, all_unit_firing_rate_temp = compute_linear_firing_rate(self.pos, self.pos_tt, spikes, + gaussian_sd=gaussian_sd, + spatial_bin_len=spatial_bin_len, + velocity=self.velocity) + return all_unit_firing_rate_temp, xx + + +def plot_tuning_curves1D(ratemap, bin_pos, ax=None, normalize=False, pad=10, unit_labels=None, fill=True, color=None, + collapsed=False): + """ + + Parameters + ---------- + ratemap: array-like + An array of dim: [number of units, bin positions] with the spike rates for a unit, at every pos, in each row + bin_pos: array-like + An array representing the bin positions of ratemap for each column + ax: matplotlib.pyplot.Axes + Axes object for the figure on which the ratemaps will be plotted + normalize: bool + default = False + Input to determine whether or not to normalize firing rates + pad: int + default = 10 + Changes to 0 if 'collapsed' is true + Amount of space to put between each unit (i.e. row) in the figure + unit_labels: array-like + Unit ids for each unit in ratemap + collapsed: bool + default = False + Determines whether to plot the ratemaps with zero padding, i.e. at the same y coordinate, on the ratemap + fill: bool, optional + + Returns + ------- + matplotlib.pyplot.Axes + + """ + xmin = bin_pos[0] + xmax = bin_pos[-1] + xvals = bin_pos + + n_units, n_ext = ratemap.shape + + if normalize: + peak_firing_rates = ratemap.max(axis=1) + ratemap = (ratemap.T / peak_firing_rates).T + pad = 1 + + if collapsed: + pad = 0 + + if xvals is None: + xvals = np.arange(n_ext) + if xmin is None: + xmin = xvals[0] + if xmax is None: + xmax = xvals[-1] + + for unit, curve in enumerate(ratemap): + if color is None: + line = ax.plot(xvals, unit * pad + curve, zorder=int(10 + 2 * n_units - 2 * unit)) + else: + line = ax.plot(xvals, unit * pad + curve, zorder=int(10 + 2 * n_units - 2 * unit), color=color) + if fill: + # Get the color from the current curve + fillcolor = line[0].get_color() + ax.fill_between(xvals, unit * pad, unit * pad + curve, alpha=0.3, color=fillcolor, + zorder=int(10 + 2 * n_units - 2 * unit - 1)) + + ax.set_xlim(xmin, xmax) + if pad != 0: + yticks = np.arange(n_units) * pad + 0.5 * pad + ax.set_yticks(yticks) + ax.set_yticklabels(unit_labels) + ax.set_xlabel('external variable') + ax.set_ylabel('unit') + ax.tick_params(axis=u'y', which=u'both', length=0) + ax.spines['left'].set_color('none') + ax.yaxis.set_ticks_position('right') + else: + ax.set_ylim(0) + if normalize: + ax.set_ylabel('normalized firing rate') + else: + ax.set_ylabel('firing rate [Hz]') + + ax.spines['top'].set_color('none') + ax.xaxis.set_ticks_position('bottom') + ax.spines['right'].set_color('none') + ax.yaxis.set_ticks_position('left') + + return ax diff --git a/nwbwidgets/view.py b/nwbwidgets/view.py index e78d7f9a..0f272377 100644 --- a/nwbwidgets/view.py +++ b/nwbwidgets/view.py @@ -20,6 +20,7 @@ timeseries, file, spectrum, + placefield ) @@ -58,12 +59,11 @@ def show_dynamic_table(node, **kwargs) -> widgets.Widget: pynwb.ProcessingModule: base.processing_module, hdmf.common.DynamicTable: show_dynamic_table, pynwb.ecephys.ElectricalSeries: ecephys.ElectricalSeriesWidget, - pynwb.behavior.SpatialSeries: OrderedDict( - { - "over time": timeseries.SeparateTracesPlotlyWidget, - "trace": behavior.plotly_show_spatial_trace, - } - ), + pynwb.behavior.Position: behavior.show_position, + pynwb.behavior.SpatialSeries: OrderedDict({ + 'over time': timeseries.SeparateTracesPlotlyWidget, + 'trace': behavior.plotly_show_spatial_trace, + 'rate map': placefield.route_placefield}), pynwb.image.GrayscaleImage: image.show_grayscale_image, pynwb.image.RGBImage: image.show_rbga_image, pynwb.image.RGBAImage: image.show_rbga_image,