diff --git a/driftmapviewer.py b/driftmapviewer.py deleted file mode 100644 index a2c5a6f..0000000 --- a/driftmapviewer.py +++ /dev/null @@ -1,521 +0,0 @@ -from pathlib import Path -from .base import BaseWidget, to_attr -import matplotlib.axis -import scipy.signal -from spikeinterface.core import read_python -import numpy as np -import pandas as pd - -import matplotlib.pyplot as plt -from scipy import stats - - -class KilosortDriftMapWidget(BaseWidget): - """ - Create a drift map plot in the kilosort style. This is ported from Nick Steinmetz's - `spikes` repository MATLAB code, https://github.com/cortex-lab/spikes. - - By default, a raster plot is drawn with the y-axis is spike depth and - x-axis is time. Optionally, a corresponding 2D activity histogram can be - added as a subplot (spatial bins, spike counts) with optional - peak coloring and drift event detection (see below). - - Parameters - ---------- - sorter_output : str | Path, - Path to the kilosort output folder. - only_include_large_amplitude_spikes : bool - If `True`, only spikes with larger amplitudes are included. For - details, see `_filter_large_amplitude_spikes()`. - decimate : None | int - If an integer n, only every nth spike is kept from the plot. Useful for improving - performance when there are many spikes. If `None`, spikes will not be decimated. - add_histogram_plot : bool - If `True`, an activity histogram will be added to a new subplot to the - left of the drift map. - add_histogram_peaks_and_boundaries : bool - If `True`, activity histogram peaks are detected and colored red if - isolated according to start/end boundaries of the peak (blue otherwise). - add_drift_events : bool - If `True`, drift events will be plot on the raster map. Required - `add_histogram_plot` and `add_histogram_peaks_and_boundaries` to run. - weight_histogram_by_amplitude : bool - If `True`, histogram counts will be weighted by spike amplitude. 
- localised_spikes_only : bool - If `True`, only spatially isolated spikes will be included. - exclude_noise : bool - If `True`, units labelled as noise in the `cluster_groups` file - will be excluded. - gain : float | None - If not `None`, amplitudes will be scaled by the supplied gain. - large_amplitude_only_segment_size: float - If `only_include_large_amplitude_spikes` is `True`, the probe is split into - segments to compute mean and std used as threshold. This sets the size of the - segments in um. - localised_spikes_channel_cutoff: int - If `localised_spikes_only` is `True`, spikes that have more than half of the - maximum loading channel over a range of > n channels are removed. - This sets the number of channels. - """ - - def __init__( - self, - sorter_output: str | Path, - only_include_large_amplitude_spikes: bool = True, - decimate: None | int = None, - add_histogram_plot: bool = False, - add_histogram_peaks_and_boundaries: bool = True, - add_drift_events: bool = True, - weight_histogram_by_amplitude: bool = False, - localised_spikes_only: bool = False, - exclude_noise: bool = False, - gain: float | None = None, - large_amplitude_only_segment_size: float = 800.0, - localised_spikes_channel_cutoff: int = 20, - ): - if not isinstance(sorter_output, Path): - sorter_output = Path(sorter_output) - - if not sorter_output.is_dir(): - raise ValueError(f"No output folder found at {sorter_output}") - - if not (sorter_output / "params.py").is_file(): - raise ValueError( - "The `sorting_output` path is not a valid kilosort output" - "folder. It does not contain a `params.py` file`." 
- ) - - plot_data = dict( - sorter_output=sorter_output, - only_include_large_amplitude_spikes=only_include_large_amplitude_spikes, - decimate=decimate, - add_histogram_plot=add_histogram_plot, - add_histogram_peaks_and_boundaries=add_histogram_peaks_and_boundaries, - add_drift_events=add_drift_events, - weight_histogram_by_amplitude=weight_histogram_by_amplitude, - localised_spikes_only=localised_spikes_only, - exclude_noise=exclude_noise, - gain=gain, - large_amplitude_only_segment_size=large_amplitude_only_segment_size, - localised_spikes_channel_cutoff=localised_spikes_channel_cutoff, - ) - BaseWidget.__init__(self, plot_data, backend="matplotlib") - - def plot_matplotlib(self, data_plot: dict, **unused_kwargs) -> None: - - dp = to_attr(data_plot) - - spike_times, spike_amplitudes, spike_depths, _ = self._compute_spike_amplitude_and_depth( - dp.sorter_output, dp.localised_spikes_only, dp.exclude_noise, dp.gain, dp.localised_spikes_channel_cutoff - ) - - # Calculate the amplitude range for plotting first, so the scale is always the - # same across all options (e.g. decimation) which helps with interpretability. 
- if dp.only_include_large_amplitude_spikes: - amplitude_range_all_spikes = ( - spike_amplitudes.min(), - spike_amplitudes.max(), - ) - else: - amplitude_range_all_spikes = np.percentile(spike_amplitudes, (1, 90)) - - if dp.decimate: - spike_times = spike_times[:: dp.decimate] - spike_amplitudes = spike_amplitudes[:: dp.decimate] - spike_depths = spike_depths[:: dp.decimate] - - if dp.only_include_large_amplitude_spikes: - spike_times, spike_amplitudes, spike_depths = self._filter_large_amplitude_spikes( - spike_times, spike_amplitudes, spike_depths, dp.large_amplitude_only_segment_size - ) - - # Setup axis and plot the raster drift map - fig = plt.figure(figsize=(10, 10 * (6 / 8))) - - if dp.add_histogram_plot: - gs = fig.add_gridspec(1, 2, width_ratios=[1, 5]) - hist_axis = fig.add_subplot(gs[0]) - raster_axis = fig.add_subplot(gs[1], sharey=hist_axis) - else: - raster_axis = fig.add_subplot() - - self._plot_kilosort_drift_map_raster( - spike_times, - spike_amplitudes, - spike_depths, - amplitude_range_all_spikes, - axis=raster_axis, - ) - - if not dp.add_histogram_plot: - raster_axis.set_xlabel("time") - raster_axis.set_ylabel("y position") - self.axes = [raster_axis] - return - - # If the histogram plot is requested, plot it alongside - # it's peak colouring, bounds display and drift point display. 
- hist_axis.set_xlabel("count") - raster_axis.set_xlabel("time") - hist_axis.set_ylabel("y position") - - bin_centers, counts = self._compute_activity_histogram( - spike_amplitudes, spike_depths, dp.weight_histogram_by_amplitude - ) - hist_axis.plot(counts, bin_centers, color="black", linewidth=1) - - if dp.add_histogram_peaks_and_boundaries: - drift_events = self._color_histogram_peaks_and_detect_drift_events( - spike_times, spike_depths, counts, bin_centers, hist_axis - ) - - if dp.add_drift_events and np.any(drift_events): - raster_axis.scatter(drift_events[:, 0], drift_events[:, 1], facecolors="r", edgecolors="none") - for i, _ in enumerate(drift_events): - raster_axis.text( - drift_events[i, 0] + 1, drift_events[i, 1], str(np.round(drift_events[i, 2])), color="r" - ) - self.axes = [hist_axis, raster_axis] - - def _plot_kilosort_drift_map_raster( - self, - spike_times: np.ndarray, - spike_amplitudes: np.ndarray, - spike_depths: np.ndarray, - amplitude_range: np.ndarray | tuple, - axis: matplotlib.axes.Axes, - ) -> None: - """ - Plot a drift raster plot in the kilosort style. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - spike_times : np.ndarray - (num_spikes,) array of spike times. - spike_amplitudes : np.ndarray - (num_spikes,) array of corresponding spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of corresponding spike depths. - amplitude_range : np.ndarray | tuple - (2,) array of min, max amplitude values for color binning. - axis : matplotlib.axes.Axes - Matplotlib axes object on which to plot the drift map. 
- """ - n_color_bins = 20 - marker_size = 0.5 - - color_bins = np.linspace(amplitude_range[0], amplitude_range[1], n_color_bins) - - colors = plt.get_cmap("gray")(np.linspace(0, 1, n_color_bins))[::-1] - - for bin_idx in range(n_color_bins - 1): - - spikes_in_amplitude_bin = np.logical_and( - spike_amplitudes >= color_bins[bin_idx], spike_amplitudes <= color_bins[bin_idx + 1] - ) - axis.scatter( - spike_times[spikes_in_amplitude_bin], - spike_depths[spikes_in_amplitude_bin], - color=colors[bin_idx], - s=marker_size, - antialiased=True, - ) - - def _compute_activity_histogram( - self, spike_amplitudes: np.ndarray, spike_depths: np.ndarray, weight_histogram_by_amplitude: bool - ) -> tuple[np.ndarray, ...]: - """ - Compute the activity histogram for the kilosort drift map's left-side plot. - - Parameters - ---------- - spike_amplitudes : np.ndarray - (num_spikes,) array of spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of spike depths. - weight_histogram_by_amplitude : bool - If `True`, the spike amplitudes are taken into consideration when generating the - histogram. The amplitudes are scaled to the range [0, 1] then summed for each bin, - to generate the histogram values. If `False`, counts (i.e. num spikes per bin) - are used. - - Returns - ------- - bin_centers : np.ndarray - The spatial bin centers (probe depth) for the histogram. - values : np.ndarray - The histogram values. If `weight_histogram_by_amplitude` is `False`, these - values represent are counts, otherwise they are counts weighted by amplitude. - """ - assert ( - spike_amplitudes.dtype == np.float64 - ), "`spike amplitudes should be high precision as many values are summed." 
- - bin_um = 2 - bins = np.arange(spike_depths.min() - bin_um, spike_depths.max() + bin_um, bin_um) - values, bins = np.histogram(spike_depths, bins=bins) - bin_centers = (bins[:-1] + bins[1:]) / 2 - - if weight_histogram_by_amplitude: - bin_indices = np.digitize(spike_depths, bins, right=True) - 1 - values = np.zeros(bin_indices.max() + 1, dtype=np.float64) - scaled_spike_amplitudes = (spike_amplitudes - spike_amplitudes.min()) / np.ptp(spike_amplitudes) - np.add.at(values, bin_indices, scaled_spike_amplitudes) - - return bin_centers, values - - def _color_histogram_peaks_and_detect_drift_events( - self, - spike_times: np.ndarray, - spike_depths: np.ndarray, - counts: np.ndarray, - bin_centers: np.ndarray, - hist_axis: matplotlib.axes.Axes, - ) -> np.ndarray: - """ - Given an activity histogram, color the peaks red (isolated peak) or - blue (peak overlaps with other peaks) and compute spatial drift - events for isolated peaks across time bins. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - spike_times : np.ndarray - (num_spikes,) array of spike times. - spike_depths : np.ndarray - (num_spikes,) array of corresponding spike depths. - counts : np.ndarray - (num_bins,) array of histogram bin counts. - bin_centers : np.ndarray - (num_bins,) array of histogram bin centers. - hist_axis : matplotlib.axes.Axes - Axes on which the histogram is plot, to add peaks. - - Returns - ------- - drift_events : np.ndarray - A (num_drift_events, 3) array of drift events. The columns are - (time_position, spatial_position, drift_value). The drift - value is computed per time, spatial bin as the difference between - the median position of spikes in the bin, and the bin center. 
- """ - all_peak_indexes = scipy.signal.find_peaks( - counts, - )[0] - - # Filter low-frequency peaks, so they are not included in the - # step to determine whether peaks are overlapping (new step - # introduced in the port to python) - bin_above_freq_threshold = counts[all_peak_indexes] > 0.3 * spike_times[-1] - filtered_peak_indexes = all_peak_indexes[bin_above_freq_threshold] - - drift_events = [] - for idx, peak_index in enumerate(filtered_peak_indexes): - - peak_count = counts[peak_index] - - # Find the start and end of peak min/max bounds (5% of amplitude) - start_position = np.where(counts[:peak_index] < peak_count * 0.05)[0].max() - end_position = np.where(counts[peak_index:] < peak_count * 0.05)[0].min() + peak_index - - if ( # bounds include another, different histogram peak - idx > 0 - and start_position < filtered_peak_indexes[idx - 1] - or idx < filtered_peak_indexes.size - 1 - and end_position > filtered_peak_indexes[idx + 1] - ): - hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="blue") - continue - - else: - for position in [start_position, end_position]: - hist_axis.axhline(bin_centers[position], 0, counts.max(), color="grey", linestyle="--") - hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="red") - - # For isolated histogram peaks, detect the drift events, defined as - # difference between spatial bin center and median spike depth in the bin - # over 6 um (in time / spatial bins with at least 10 spikes). 
- depth_in_window = np.logical_and( - spike_depths > bin_centers[start_position], - spike_depths < bin_centers[end_position], - ) - current_spike_depths = spike_depths[depth_in_window] - current_spike_times = spike_times[depth_in_window] - - window_s = 10 - - all_time_bins = np.arange(0, np.ceil(spike_times[-1]).astype(int), window_s) - for time_bin in all_time_bins: - - spike_in_time_bin = np.logical_and( - current_spike_times >= time_bin, current_spike_times <= time_bin + window_s - ) - drift_size = bin_centers[peak_index] - np.median(current_spike_depths[spike_in_time_bin]) - - # 6 um is the hardcoded threshold for drift, and we want at least 10 spikes for the median calculation - bin_has_drift = np.abs(drift_size) > 6 and np.sum(spike_in_time_bin, dtype=np.int16) > 10 - if bin_has_drift: - drift_events.append((time_bin + window_s / 2, bin_centers[peak_index], drift_size)) - - drift_events = np.array(drift_events) - - return drift_events - - def _compute_spike_amplitude_and_depth( - self, - sorter_output: str | Path, - localised_spikes_only, - exclude_noise, - gain: float | None, - localised_spikes_channel_cutoff: int, - ) -> tuple[np.ndarray, ...]: - """ - Compute the amplitude and depth of all detected spikes from the kilosort output. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - sorter_output : str | Path - Path to the kilosort run sorting output. - localised_spikes_only : bool - If `True`, only spikes with small spatial footprint (i.e. 20 channels within 1/2 of the - amplitude of the maximum loading channel) and which are close to the average depth for - the cluster are returned. - gain: float | None - If a float provided, the `spike_amplitudes` will be scaled by this gain. 
- localised_spikes_channel_cutoff : int - If `localised_spikes_only` is `True`, spikes that have less than half of the - maximum loading channel over a range of n channels are removed. - This sets the number of channels. - - Returns - ------- - spike_times : np.ndarray - (num_spikes,) array of spike times. - spike_amplitudes : np.ndarray - (num_spikes,) array of corresponding spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of corresponding depths (probe y-axis location). - - Notes - ----- - In `_template_positions_amplitudes` spike depths is calculated as simply the template - depth, for each spike (so it is the same for all spikes in a cluster). Here we need - to find the depth of each individual spike, using its low-dimensional projection. - `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. - Taking the first component, the subset of 32 channels associated with this - spike are indexed to get the actual channel locations (in um). Then, the channel - locations are weighted by their PC values. 
- """ - if isinstance(sorter_output, str): - sorter_output = Path(sorter_output) - - params = self._load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise) - - if localised_spikes_only: - localised_templates = [] - - for idx, template in enumerate(params["templates"]): - max_channel = np.max(np.abs(params["templates"][idx, :, :])) - channels_over_threshold = np.max(np.abs(params["templates"][idx, :, :]), axis=0) > 0.5 * max_channel - channel_ids_over_threshold = np.where(channels_over_threshold)[0] - - if np.ptp(channel_ids_over_threshold) <= localised_spikes_channel_cutoff: - localised_templates.append(idx) - - localised_template_by_spike = np.isin(params["spike_templates"], localised_templates) - - params["spike_templates"] = params["spike_templates"][localised_template_by_spike] - params["spike_times"] = params["spike_times"][localised_template_by_spike] - params["spike_clusters"] = params["spike_clusters"][localised_template_by_spike] - params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike] - params["pc_features"] = params["pc_features"][localised_template_by_spike] - - # Compute spike depths - pc_features = params["pc_features"][:, 0, :] - pc_features[pc_features < 0] = 0 - - # Get the channel indexes corresponding to the 32 channels from the PC. - spike_features_indices = params["pc_features_indices"][params["spike_templates"], :] - - ycoords = params["channel_positions"][:, 1] - spike_feature_ycoords = ycoords[spike_features_indices] - - spike_depths = np.sum(spike_feature_ycoords * pc_features**2, axis=1) / np.sum(pc_features**2, axis=1) - - # Compute amplitudes, scale if required and drop un-localised spikes before returning. 
- spike_amplitudes, _, _, _, unwhite_templates, *_ = self._template_positions_amplitudes( - params["templates"], - params["whitening_matrix_inv"], - ycoords, - params["spike_templates"], - params["temp_scaling_amplitudes"], - ) - - if gain is not None: - spike_amplitudes *= gain - - max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1) - spike_sites = max_site[params["spike_templates"]] - - if localised_spikes_only: - # Interpolate the channel ids to location. - # Remove spikes > 5 um from average position - # Above we already removed non-localized templates, but that on its own is insufficient. - # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient - b = stats.linregress(spike_depths, spike_sites).slope - i = np.abs(spike_sites - b * spike_depths) <= 5 - - params["spike_times"] = params["spike_times"][i] - spike_amplitudes = spike_amplitudes[i] - spike_depths = spike_depths[i] - - return params["spike_times"], spike_amplitudes, spike_depths, spike_sites - - def _filter_large_amplitude_spikes( - self, - spike_times: np.ndarray, - spike_amplitudes: np.ndarray, - spike_depths: np.ndarray, - large_amplitude_only_segment_size, - ) -> tuple[np.ndarray, ...]: - """ - Return spike properties with only the largest-amplitude spikes included. The probe - is split into egments, and within each segment the mean and std computed. - Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded - Splitting the probe is only done for the exclusion step, the returned array are flat. - - Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns - copies of these arrays containing only the large amplitude spikes. 
- """ - spike_bool = np.zeros_like(spike_amplitudes, dtype=bool) - - segment_size_um = large_amplitude_only_segment_size - probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um - - for segment_left_edge in probe_segments_left_edges: - segment_right_edge = segment_left_edge + segment_size_um - - spikes_in_seg = np.where( - np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge) - )[0] - spike_amps_in_seg = spike_amplitudes[spikes_in_seg] - is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1) - - spike_bool[spikes_in_seg] = is_high_amplitude - - spike_times = spike_times[spike_bool] - spike_amplitudes = spike_amplitudes[spike_bool] - spike_depths = spike_depths[spike_bool] - - return spike_times, spike_amplitudes, spike_depths - - - diff --git a/example_playing.py b/example_playing.py new file mode 100644 index 0000000..be72233 --- /dev/null +++ b/example_playing.py @@ -0,0 +1,102 @@ +from mpl_plotting.driftmapviewer_new import get_drift_map_plot, _plot_kilosort_drift_map_raster +import matplotlib.pyplot as plt +from ks_extractors import kilosort1_3 +from ks_extractors import kilosort_4 +from ks_extractors import helpers +from pathlib import Path +import numpy as np +from interactive.driftmap_plot_widget import DriftmapPlotWidget +import numpy as np +import pyqtgraph as pg +from pyqtgraph.Qt import QtWidgets, QtCore +import matplotlib.pyplot as plt + +# TODO +# ---- +# - review driftmap_view and multi_session_drift_map +# - review KS4 and KS3 method, in particular the scaling and whitening. Check ks4 vs. ks25 +# +# +# +# Don't forget to write in docs the caveats on the extraction + + + +# TODO Extra - look into: +# TODO idea: memmap the npy files and decimate ON LOAD +# TODO: KS can wrap template channels around probe boundaries (e.g. channels +# [0,1,2,380,381,382] for a template near the top). 
We unwrap these in +# _get_nonzero_channel_indices. Post about this on the Kilosort GitHub +# and consider adding an option to disable the unwrap. + +# This makes the assumption that there will never be different .csv and .tsv files +# in the same sorter output (this should never happen, there will never even be two). +# Though can be saved as .tsv, it seems the .csv is also tab formatted as far as pandas is concerned. + +# TODO: this is super weird, can be improved? +# if log_transform_amplitudes: +# spike_amplitudes = np.log(spike_amplitudes) # TODO: give optional (None, 2 or 10) + + + + +# TODO: dont use gain, instead set clim +# TODO: removed localised peaks +# TODO: removed drift event and boundaries + +# TODO: it would be really cool and useful to hover over +# the plot and see the template waveform + +# TODO idea: memmap the npy files and decimate ON LOAD + +# TODO: KS can wrap template channels around probe boundaries (e.g. channels +# [0,1,2,380,381,382] for a template near the top). We unwrap these in +# _get_nonzero_channel_indices. Post about this on the Kilosort GitHub +# and consider adding an option to disable the unwrap. 
+ +from PySide6 import QtWidgets +from interactive.driftmap_view import DriftMapView + +app = QtWidgets.QApplication.instance() or QtWidgets.QApplication([]) + +FILES = [ + r"X:\aeon\dj_store\ephys-processed\social-ephys0.1-aeon3\ephys_blocks\2024-06-04T11-00-00_2024-06-04T14-00-00\0-95\kilosort4_400\spike_sorting\sorter_output", + r"X:\aeon\dj_store\ephys-processed\social-ephys0.1-aeon3\ephys_blocks\2024-06-04T13-00-00_2024-06-04T16-00-00\0-95\kilosort4_400\spike_sorting\sorter_output" +] + +pooled_amplitudes = helpers.get_pooled_amplitudes( + FILES +) + +min_, max_ = np.percentile(pooled_amplitudes, (97, 99)) # TODO: add exclude noise + + +panels = [] +for file in FILES: #[ + # r"Y:\public\projects\BeJG_20230130_VisDetect\wEPhys\BG_046\joe\scratch\derivatives\BG_046_26062025\shank_1\sorting\no_motion\sorter_output", + # r"Y:\public\projects\BeJG_20230130_VisDetect\wEPhys\BG_046\joe\scratch\derivatives\BG_046_27062025\shank_1\sorting\motion\sorter_output", + # r"X:\aeon\dj_store\ephys-processed\social-ephys0.1-aeon3\ephys_blocks\2024-06-04T11-00-00_2024-06-04T14-00-00\0-95\kilosort4_400\spike_sorting\sorter_output", + # r"X:\aeon\dj_store\ephys-processed\social-ephys0.1-aeon3\ephys_blocks\2024-06-04T13-00-00_2024-06-04T16-00-00\0-95\kilosort4_400\spike_sorting\sorter_output", + # r"X:\aeon\dj_store\ephys-processed\social-ephys0.1-aeon3\ephys_blocks\2024-06-04T15-00-00_2024-06-04T18-00-00\0-95\kilosort4_400\spike_sorting\sorter_output" + + +#]: + plotter = DriftMapView( + file + ) + + fig = plotter.drift_map_plot_interactive( + decimate=10, + exclude_noise=True, + amplitude_scaling=(min_, max_), + n_color_bins=25, + filter_amplitude_mode="absolute", # "percentile", + filter_amplitude_values=(min_, max_) + ) + + panels.append(fig) + +from interactive.multi_session_drift_map import MultiSessionDriftmapWidget +multi = MultiSessionDriftmapWidget(panels) + +app.exec() diff --git a/interactive/__init__.py b/interactive/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/interactive/driftmap_plot_widget.py b/interactive/driftmap_plot_widget.py new file mode 100644 index 0000000..141b2b7 --- /dev/null +++ b/interactive/driftmap_plot_widget.py @@ -0,0 +1,580 @@ +import numpy as np +import pyqtgraph as pg +from pyqtgraph.Qt import QtWidgets, QtCore +import matplotlib.pyplot as plt + +pg.setConfigOption("background", "w") +pg.setConfigOption("foreground", "k") +pg.setConfigOption("antialias", True) + + +class DriftmapPlotWidget(QtWidgets.QWidget): + """ + """ + def __init__(self, spike_times, spike_amplitudes, spike_depths, + spike_templates, templates, channel_positions, + amplitude_scaling="linear", n_color_bins=20, + point_size=5.0, sorter_path=None): + super().__init__() + + print(f"Loaded {spike_times.size} spikes from {sorter_path}") + + self.spike_times = spike_times + self.spike_amplitudes = spike_amplitudes + self.spike_depths = spike_depths + self.spike_templates = spike_templates + self.templates = templates + self.channel_positions = channel_positions + + self.cfgs = { + "right_panel_view_mode": "heatmap", + "left_panel_y_axis": { + "on": False, + "y_max": 200, + "y_min": -200, + }, + } + + self.selected_spot = None + self._trace_view_initialized = False + + self.resize(1400, 820) + + # Instantiate UI and scatter plot + win_left, win_right = self._init_ui() + self._init_scatter_plot(win_left, spike_times, spike_amplitudes, + spike_depths, amplitude_scaling, + n_color_bins, point_size) + self._init_panel_plot(win_right) + + # Connect widgets + self.ymin_spin.valueChanged.connect(self.handle_y_spinbox_min) + self.ymax_spin.valueChanged.connect(self.handle_y_spinbox_max) + self._fix_limits_cb.toggled.connect(self.handle_fix_ylim_cb) + self.scatter.sigClicked.connect(self.handle_click) + self._view_radio_group.idToggled.connect(self.handle_view_radio_toggled) + + def _init_scatter_plot(self, win_left, spike_times, spike_amplitudes, + spike_depths, amplitude_scaling, n_color_bins, + point_size): + 
"""Create the scatter plot on the left panel. + + Parameters + ---------- + win_left : pg.GraphicsLayoutWidget + The left graphics area to host the scatter plot. + spike_times, spike_amplitudes, spike_depths : np.ndarray + Spike data arrays. + amplitude_scaling : str | tuple + Colour-scaling mode or explicit (min, max) range. + n_color_bins : int + Number of grey-scale colour bins. + point_size : float + Scatter-point diameter in pixels. + """ + self.p_scatter = win_left.addPlot(row=0, col=0) + self.p_scatter.setLabel("bottom", "Time (s)") + self.p_scatter.setLabel("left", "Depth (µm)") + self.p_scatter.showGrid(x=False, y=False) + + # set amplitude colors + rgba_colors = self._compute_amplitude_colors( + spike_amplitudes, amplitude_scaling, n_color_bins + ) + + # set axis limits, pad around them slightly + x_pad = (spike_times.max() - spike_times.min()) * 0.025 + y_pad = (spike_depths.max() - spike_depths.min()) * 0.05 + + self.p_scatter.getViewBox().setLimits( + xMin=spike_times.min() - x_pad, + xMax=spike_times.max() + x_pad, + yMin=spike_depths.min() - y_pad, + yMax=spike_depths.max() + y_pad, + ) + + # create plot — each point stores its spike index in 'data' for click/tooltip lookup + self.scatter = pg.ScatterPlotItem( + spike_times, + spike_depths, + pxMode=True, + size=point_size, + hoverable=True, + antialias=True, + data=np.arange(spike_times.size), + brush=rgba_colors, + pen=None, + tip=lambda x, y, data: ( + f"x={x:.3f}\ny={y:.1f}\n" + f"amp={self.spike_amplitudes[int(data)]:.2f}" + ), + ) + self.p_scatter.addItem(self.scatter) + + def _init_panel_plot(self, win_right): + """Create the template panel plot on the right side. + + Parameters + ---------- + win_right : pg.GraphicsLayoutWidget + The right graphics area to host the panel plot. 
+ """ + self.panel_plot = win_right.addPlot(row=0, col=0) + self.panel_plot.setLabel("bottom", "sample") + self.panel_plot.setLabel("left", "amplitude") + self.panel_plot.showGrid(x=False, y=False) + + def _connect_signals(self): + """Wire up Qt signal/slot connections.""" + self.ymin_spin.valueChanged.connect(self.handle_y_spinbox_min) + self.ymax_spin.valueChanged.connect(self.handle_y_spinbox_max) + self._fix_limits_cb.toggled.connect(self.handle_fix_ylim_cb) + self.scatter.sigClicked.connect(self.handle_click) + self._view_radio_group.idToggled.connect(self.handle_view_radio_toggled) + + def handle_view_radio_toggled(self, button_id, checked): + if not checked: + return + + mode_map = {0: "max_waveform", 1: "heatmap", 2: "heatmap_all_channels", 3: "trace_view"} + mode = mode_map[button_id] + self.cfgs["right_panel_view_mode"] = mode + self._trace_view_initialized = False + + if mode == "trace_view": + self._limits_page.setVisible(False) + else: + self._limits_page.setVisible(True) + if mode == "max_waveform": + self._fix_limits_cb.setText("Fix y-limits") + self._min_label.setText("Y min:") + self._max_label.setText("Y max:") + else: + self._fix_limits_cb.setText("Fix color limits") + self._min_label.setText("C min:") + self._max_label.setText("C max:") + + if self.selected_spot is not None: + self.set_y_limit() + self.update_panel(int(self.selected_spot.data())) + + def handle_y_spinbox_min(self, value): + self.cfgs["left_panel_y_axis"]["y_min"] = value + if self.selected_spot is not None: + self.set_y_limit() + self.update_panel(int(self.selected_spot.data())) + + def handle_y_spinbox_max(self, value): + self.cfgs["left_panel_y_axis"]["y_max"] = value + if self.selected_spot is not None: + self.set_y_limit() + self.update_panel(int(self.selected_spot.data())) + + def handle_fix_ylim_cb(self, active): + self.ymin_spin.setEnabled(active) + self.ymax_spin.setEnabled(active) + self.cfgs["left_panel_y_axis"]["on"] = active + self.set_y_limit() + if 
self.selected_spot is None: + return + self.update_panel(int(self.selected_spot.data())) + + def set_y_limit(self): + mode = self.cfgs["right_panel_view_mode"] + if mode == "max_waveform": + if self.cfgs["left_panel_y_axis"]["on"]: + self.panel_plot.setYRange(self.ymin_spin.value(), self.ymax_spin.value(), padding=0) + else: + self.panel_plot.enableAutoRange(axis='y') + elif mode in ("heatmap", "heatmap_all_channels"): + pass # color limits applied during draw + else: + self.panel_plot.enableAutoRange() + + def handle_click(self, _, points, __): + if points is None or len(points) <= 0: + return + + spot = points[0] + + if self.selected_spot is not None: + self.selected_spot.setPen(pg.mkPen(None)) + + spot.setPen(pg.mkPen('r', width=2)) + self.selected_spot = spot + + idx = int(spot.data()) + self.update_panel(idx) + + def update_panel(self, spike_idx): + template_id = int(self.spike_templates[spike_idx]) + self.panel_plot.setTitle(f"Template {template_id}") + + if self.cfgs["right_panel_view_mode"] == "max_waveform": + self._draw_max_waveform_on_panel(spike_idx) + elif self.cfgs["right_panel_view_mode"] == "trace_view": + self._draw_template_trace_view_on_panel(spike_idx) + else: + self._draw_template_heatmap_on_panel(spike_idx) + + # TODO: carefully check these!!! 
+ + def _draw_max_waveform_on_panel(self, spike_index): + template_waveform = self.get_max_waveform_data(spike_index) + n_samples = template_waveform.size + + pen = pg.mkPen("k", width=2.5) + self.panel_plot.clear() + self.panel_plot.plot(np.arange(n_samples), template_waveform, pen=pen) + + self.panel_plot.setLabel("bottom", "sample") + self.panel_plot.setLabel("left", "amplitude") + self.panel_plot.getAxis("left").setTicks(None) + self.panel_plot.getAxis("left").setStyle(showValues=True) + self.panel_plot.setXRange(0, n_samples, padding=0.05) + + def _draw_template_heatmap_on_panel(self, spike_index): + template_waveform_2d = self.get_heatmap_data(spike_index) + n_samples, n_chans = template_waveform_2d.shape[0], template_waveform_2d.shape[1] + + self.panel_plot.clear() + + if self.cfgs["right_panel_view_mode"] == "heatmap_all_channels": + self.panel_plot.setLabel("left", "channel") + self.panel_plot.getAxis("left").setTicks(None) + self.panel_plot.getAxis("left").setStyle(showValues=True) + else: + self.panel_plot.setLabel("left", "") + self.panel_plot.getAxis("left").setTicks([]) + self.panel_plot.getAxis("left").setStyle(showValues=False) + self.panel_plot.setLabel("bottom", "sample") + + image_item = pg.ImageItem() + self.panel_plot.addItem(image_item) + + colors = [ + (0, 0, 180, 255), + (255, 255, 255, 255), + (180, 0, 0, 255), + ] + cmap = pg.ColorMap(pos=[0.0, 0.5, 1.0], color=colors) + image_item.setColorMap(cmap) + + if self.cfgs["left_panel_y_axis"]["on"]: + image_item.setLevels(( + self.ymin_spin.value(), + self.ymax_spin.value(), + )) + image_item.setImage(template_waveform_2d) + image_item.setRect(0, 0, n_samples, n_chans) + self.panel_plot.setXRange(0, n_samples, padding=0.05) + self.panel_plot.setYRange(0, n_chans, padding=0.05) + + def _draw_template_trace_view_on_panel(self, spike_index): + template_idx = self.spike_templates[spike_index] + wv = self.templates[template_idx].copy() * self.spike_amplitudes[spike_index] + n_samples, n_chan = 
wv.shape + + # Use only channels with non-zero data, unwrapping if KS wrapped them + contains_data_idx, is_wrapped = self._get_nonzero_channel_indices(wv) + wv = wv[:, contains_data_idx] + xc = self.channel_positions[contains_data_idx, 0] + yc = self.channel_positions[contains_data_idx, 1] + + if is_wrapped: + xc, yc = self._make_positions_contiguous(xc, yc) + + chan_spacing, x_spacing = self._compute_trace_spacing(xc, yc) + amp_scale = (chan_spacing * 0.45) / np.max(np.abs(wv)) if np.max(np.abs(wv)) > 0 else 1.0 + + self._plot_traces(wv, xc, yc, n_samples, x_spacing, amp_scale) + self._set_trace_view_range(xc, yc, chan_spacing, x_spacing) + + @staticmethod + def _compute_trace_spacing(xc, yc): + """Compute channel and x spacing from probe positions. + + Returns + ------- + chan_spacing : float + Minimum y-distance between adjacent channels. + x_spacing : float + Minimum x-distance between adjacent channel columns. + """ + unique_y = np.unique(yc) + chan_spacing = np.min(np.diff(np.sort(unique_y))) if len(unique_y) > 1 else 1.0 + + unique_x = np.unique(xc) + x_spacing = np.min(np.diff(np.sort(unique_x))) if len(unique_x) > 1 else 20.0 + + return chan_spacing, x_spacing + + def _plot_traces(self, wv, xc, yc, n_samples, x_spacing, amp_scale): + """Draw one waveform trace per channel on the panel plot.""" + self.panel_plot.clear() + + # Time axis in physical units, scaled to fit within ~90% of x_spacing + t = np.arange(-n_samples // 2, n_samples // 2, 1, dtype=np.float32) + t *= (x_spacing * 0.9) / n_samples + + for ii, (xi, yi) in enumerate(zip(xc, yc)): + self.panel_plot.plot( + xi + t, yi + wv[:, ii] * amp_scale, + pen=pg.mkPen('k', width=0.5), + ) + + self.panel_plot.setLabel("bottom", "x position") + self.panel_plot.setLabel("left", "y position (\u00b5m)") + self.panel_plot.getAxis("left").setTicks(None) + self.panel_plot.getAxis("left").setStyle(showValues=True) + + def _set_trace_view_range(self, xc, yc, chan_spacing, x_spacing): + """Set the view range for the 
trace view, centering on first click + and preserving zoom on subsequent clicks.""" + y_center = (yc.min() + yc.max()) / 2 + x_center = (xc.min() + xc.max()) / 2 + + if not self._trace_view_initialized: + y_pad = (yc.max() - yc.min()) * 0.15 + chan_spacing + x_pad = (xc.max() - xc.min()) * 0.15 + x_spacing + self.panel_plot.setXRange(xc.min() - x_pad, xc.max() + x_pad, padding=0) + self.panel_plot.setYRange(yc.min() - y_pad, yc.max() + y_pad, padding=0) + self._trace_view_initialized = True + else: + view_box = self.panel_plot.getViewBox() + [[x_lo, x_hi], [y_lo, y_hi]] = view_box.viewRange() + y_half = (y_hi - y_lo) / 2 + x_half = (x_hi - x_lo) / 2 + self.panel_plot.setXRange(x_center - x_half, x_center + x_half, padding=0) + self.panel_plot.setYRange(y_center - y_half, y_center + y_half, padding=0) + + def get_max_waveform_data(self, spike_index): + template_idx = self.spike_templates[spike_index] + scaled_template = self.templates[template_idx, :, :] * self.spike_amplitudes[spike_index] + peak_ch = np.argmax(np.max(np.abs(scaled_template), axis=0)) + return scaled_template[:, peak_ch] + + @staticmethod + def _compute_amplitude_colors(spike_amplitudes, amplitude_scaling, n_color_bins): + """Map spike amplitudes to RGBA colours via grey-scale binning. + + Parameters + ---------- + spike_amplitudes : np.ndarray + (num_spikes,) raw amplitude values. + amplitude_scaling : {"linear", "log2", "log10"} | tuple + Scaling mode. A 2-tuple ``(min, max)`` fixes the colour + range explicitly. + n_color_bins : int + Number of grey-scale bins. + + Returns + ------- + np.ndarray + (num_spikes, 4) uint8 RGBA values in [0, 255]. 
+ """ + amp_values = spike_amplitudes.copy() + + if isinstance(amplitude_scaling, tuple): + amp_min, amp_max = amplitude_scaling + elif amplitude_scaling == "log2": + amp_values = np.log2(amp_values) + amp_min, amp_max = amp_values.min(), amp_values.max() + elif amplitude_scaling == "log10": + amp_values = np.log10(amp_values) + amp_min, amp_max = amp_values.min(), amp_values.max() + else: # "linear" + amp_min, amp_max = amp_values.min(), amp_values.max() + + color_bins = np.linspace(amp_min, amp_max, n_color_bins) + gray_colors = plt.get_cmap("gray")(np.linspace(0, 1, n_color_bins))[::-1] + bin_indices = np.clip( + np.searchsorted(color_bins, amp_values, side="right") - 1, + 0, n_color_bins - 2, + ) + return (gray_colors[bin_indices] * 255).astype(np.uint8) + + @staticmethod + def _get_nonzero_channel_indices(scaled_template): + """Get indices of channels with data, unwrapping if KS wrapped them. + + Kilosort can wrap template channel assignments around the probe + boundaries (e.g. channels [0, 1, 2, 380, 381, 382]). This detects + the wrap and reorders so the high-index group comes first, + giving a spatially contiguous result. + + Returns + ------- + contains_data_idx : np.ndarray + Channel indices with data, reordered if wrapping detected. + is_wrapped : bool + True if wrapping was detected and corrected. + """ + # KS templates have zeros on inactive channels; checking row 0 is + # sufficient because active channels always have non-zero values at + # the first sample (the template extends across all samples). 
+ contains_data_idx = np.where(scaled_template[0, :] != 0)[0] + + if len(contains_data_idx) < 2: + return contains_data_idx, False + + # Check for a gap (non-contiguous indices) + diffs = np.diff(contains_data_idx) + gap_positions = np.where(diffs > 1)[0] + + if len(gap_positions) == 1: + # Wrapped: split at gap, put the higher-index group first + split = gap_positions[0] + 1 + contains_data_idx = np.concatenate([ + contains_data_idx[split:], + contains_data_idx[:split], + ]) + return contains_data_idx, True + + return contains_data_idx, False + + @staticmethod + def _make_positions_contiguous(xc, yc): + """Remap channel positions so wrapped channels sit contiguously. + + When KS wraps, channels from the top and bottom of the probe + end up in the same template. We shift the lower-position group + to sit just above the higher-position group (or vice versa) + so they display as one contiguous block. + """ + sorted_y = np.sort(np.unique(yc)) + if len(sorted_y) < 2: + return xc, yc.copy() + + y_diffs = np.diff(sorted_y) + largest_gap_idx = np.argmax(y_diffs) + gap_size = y_diffs[largest_gap_idx] + typical_spacing = np.median(y_diffs) + + if gap_size > typical_spacing * 3: + gap_threshold = sorted_y[largest_gap_idx] + gap_size / 2 + yc_new = yc.copy() + upper_mask = yc >= gap_threshold + shift = yc[upper_mask].min() - yc[~upper_mask].max() - typical_spacing + yc_new[upper_mask] -= shift + return xc, yc_new + + return xc, yc.copy() + + def get_heatmap_data(self, spike_index): + """""" + template_idx = self.spike_templates[spike_index] + scaled_template = self.templates[template_idx, :, :] * self.spike_amplitudes[spike_index] + + if self.cfgs["right_panel_view_mode"] == "heatmap_all_channels": + scaled_template = scaled_template.copy() # TODO: check if this is necessary + scaled_template[:, scaled_template[0, :] == 0] = np.nan + else: + contains_data_idx, _ = self._get_nonzero_channel_indices(scaled_template) + scaled_template = scaled_template[:, contains_data_idx] + + 
return scaled_template + + + # UI - To possibly be moved to QtDesigner + # ---------------------------------------------------------------------------------- + + def _init_ui(self): + """Build the widget layout: splitter, controls bar, radio buttons, spinboxes. + + Returns + ------- + win_left : pg.GraphicsLayoutWidget + Left graphics area (for the scatter plot). + win_right : pg.GraphicsLayoutWidget + Right graphics area (for the template panel). + """ + # Core layout + outer_layout = QtWidgets.QVBoxLayout(self) + outer_layout.setContentsMargins(0, 0, 0, 0) + outer_layout.setSpacing(0) + + splitter = QtWidgets.QSplitter(QtCore.Qt.Orientation.Horizontal) + splitter.setChildrenCollapsible(False) + outer_layout.addWidget(splitter, stretch=1) + + win_left = pg.GraphicsLayoutWidget() + splitter.addWidget(win_left) + + right_widget = QtWidgets.QWidget() + right_layout = QtWidgets.QVBoxLayout(right_widget) + right_layout.setContentsMargins(0, 0, 0, 0) + right_layout.setSpacing(0) + + win_right = pg.GraphicsLayoutWidget() + right_layout.addWidget(win_right, stretch=1) + splitter.addWidget(right_widget) + total = splitter.width() + splitter.setSizes([int(total * 0.75), int(total * 0.25)]) + + # Controls Bar — below the full splitter + # -------------------------------------------------------------------------------------------------------------- + controls_widget = QtWidgets.QWidget() + controls_layout = QtWidgets.QVBoxLayout(controls_widget) + controls_layout.setContentsMargins(6, 6, 6, 6) + controls_layout.setSpacing(6) + + # --- Radio buttons --- + radio_row = QtWidgets.QWidget() + radio_layout = QtWidgets.QHBoxLayout(radio_row) + radio_layout.setContentsMargins(0, 0, 0, 0) + radio_layout.setSpacing(12) + + self.radio_max_wf = QtWidgets.QRadioButton("Max waveform") + self.radio_heatmap = QtWidgets.QRadioButton("Heatmap") + self.radio_heatmap_all = QtWidgets.QRadioButton("Heatmap (all channels)") + self.radio_trace_view = QtWidgets.QRadioButton("Trace view") + 
self.radio_heatmap.setChecked(True) + + self._view_radio_group = QtWidgets.QButtonGroup(self) + self._view_radio_group.addButton(self.radio_max_wf, 0) + self._view_radio_group.addButton(self.radio_heatmap, 1) + self._view_radio_group.addButton(self.radio_heatmap_all, 2) + self._view_radio_group.addButton(self.radio_trace_view, 3) + + radio_layout.addWidget(self.radio_max_wf) + radio_layout.addWidget(self.radio_heatmap) + radio_layout.addWidget(self.radio_heatmap_all) + radio_layout.addWidget(self.radio_trace_view) + radio_layout.addStretch() + controls_layout.addWidget(radio_row) + + # Limits controls row + self._limits_page = QtWidgets.QWidget() + limits_layout = QtWidgets.QHBoxLayout(self._limits_page) + limits_layout.setContentsMargins(0, 0, 0, 0) + limits_layout.setSpacing(8) + + self._fix_limits_cb = QtWidgets.QCheckBox("Fix y-limits") + self.ymin_spin = QtWidgets.QDoubleSpinBox() + self.ymax_spin = QtWidgets.QDoubleSpinBox() + for spin in (self.ymin_spin, self.ymax_spin): + spin.setRange(-1e9, 1e9) + spin.setDecimals(1) + spin.setFixedWidth(90) + spin.setMinimumWidth(100) + spin.setButtonSymbols(QtWidgets.QAbstractSpinBox.ButtonSymbols.NoButtons) + spin.setEnabled(self.cfgs["left_panel_y_axis"]["on"]) + self.ymin_spin.setValue(self.cfgs["left_panel_y_axis"]["y_min"]) + self.ymax_spin.setValue(self.cfgs["left_panel_y_axis"]["y_max"]) + + limits_layout.addWidget(self._fix_limits_cb) + self._min_label = QtWidgets.QLabel("Y min:") + self._max_label = QtWidgets.QLabel("Y max:") + limits_layout.addWidget(self._min_label) + limits_layout.addWidget(self.ymin_spin) + limits_layout.addWidget(self._max_label) + limits_layout.addWidget(self.ymax_spin) + limits_layout.addStretch() + + controls_layout.addWidget(self._limits_page) + + # Add controls below the splitter, spanning full width + outer_layout.addWidget(controls_widget) + + return win_left, win_right diff --git a/interactive/driftmap_view.py b/interactive/driftmap_view.py new file mode 100644 index 
0000000..af36959 --- /dev/null +++ b/interactive/driftmap_view.py @@ -0,0 +1,277 @@ +from mpl_plotting.driftmapviewer_new import get_drift_map_plot, _plot_kilosort_drift_map_raster +import matplotlib.pyplot as plt +from ks_extractors import kilosort1_3 +from ks_extractors import kilosort_4 +from ks_extractors import helpers +from pathlib import Path +import numpy as np +from .driftmap_plot_widget import DriftmapPlotWidget +import numpy as np +import pyqtgraph as pg +from pyqtgraph.Qt import QtWidgets, QtCore +import matplotlib.pyplot as plt + + +class DriftMapView(): + """Load Kilosort sorter output and provide interactive or static drift map plots. + + On construction, spike data is loaded from a Kilosort output directory + and stored as read-only arrays. Plotting methods apply optional + filtering (noise exclusion, amplitude filtering, decimation) before + handing the data to a plot backend. + + Parameters + ---------- + sorter_path : str | Path + Path to a Kilosort sorter output directory. Must contain + exactly one ``kilosort*.log`` file used to detect the KS version. + + Attributes + ---------- + spike_times : np.ndarray + (num_spikes,) spike times (seconds for KS 1-3, samples for KS4). + spike_amplitudes : np.ndarray + (num_spikes,) spike amplitudes. + spike_depths : np.ndarray + (num_spikes,) spike depths along the probe (µm). + spike_templates : np.ndarray + (num_spikes,) template index assigned to each spike. + templates : np.ndarray + (num_templates, num_samples, num_channels) template waveforms. + channel_positions : np.ndarray + (num_channels, 2) x/y positions of each channel on the probe. + + TODO + ---- + - Evaluate memory cost of holding all arrays; consider lazy / mmap loading. + - Harmonise spike_times units (seconds everywhere). + """ + def __init__(self, sorter_path): + """Load spike data from a Kilosort output directory. + + Parameters + ---------- + sorter_path : str | Path + Path to the Kilosort sorter output. 
+ + Raises + ------ + AssertionError + If the directory does not contain exactly one ``kilosort*.log`` + file, or if the loaded spike arrays have mismatched sizes. + """ + self.sorter_path = Path(sorter_path) + + log_file = list(self.sorter_path.glob("kilosort*.log")) + assert len(log_file) == 1 + self.ks_version = Path(log_file[0]).name.split(".")[0] + + func = kilosort_4.get_spikes_info_ks4 if self.ks_version == "kilosort4" else kilosort1_3.get_spikes_info_ks1_3 + + ( + self.spike_times, + self.spike_amplitudes, + self.spike_depths, + self.spike_templates, + self.templates, + self.channel_positions + ) = func( + self.sorter_path + ) + + assert self.spike_times.size == self.spike_amplitudes.size == self.spike_depths.size == self.spike_templates.size + + self.spike_times.flags.writeable = False + self.spike_amplitudes.flags.writeable = False + self.spike_depths.flags.writeable = False + self.spike_templates.flags.writeable = False + self.templates.flags.writeable = False + self.channel_positions.flags.writeable = False + + def _process_data( + self, + exclude_noise, + decimate, + filter_amplitude_mode, + filter_amplitude_values + ): + """Filter and subsample the loaded spike data. + + Operations are applied in order: decimation → noise exclusion → + amplitude filtering → masking. Decimation is applied first as a + performance knob to thin the full dataset before further filtering. + + Parameters + ---------- + exclude_noise : bool + If ``True``, spikes belonging to clusters labelled "noise" in + the Kilosort cluster groups file are removed. + decimate : int | False + Keep every *n*-th spike. Applied first to reduce the dataset + before noise/amplitude filters. ``False`` disables decimation. + filter_amplitude_mode : {"percentile", "absolute"} | None + How ``filter_amplitude_values`` is interpreted. + ``None`` disables amplitude filtering. + filter_amplitude_values : tuple of float + (low, high) bounds. 
Interpreted as percentile ranks or + absolute amplitude values depending on ``filter_amplitude_mode``. + + Returns + ------- + spike_times : np.ndarray + spike_amplitudes : np.ndarray + spike_depths : np.ndarray + spike_templates : np.ndarray + Filtered copies (views when no filtering is needed) of the + corresponding instance arrays. + """ + # Select a view for now, this may be copied depending on options (e.g. decimate) + spike_times = self.spike_times + spike_amplitudes = self.spike_amplitudes + spike_depths = self.spike_depths + spike_templates = self.spike_templates + + keep_bool_mask = None + + if exclude_noise: + keep_bool_mask = ~helpers.get_noise_mask( + self.sorter_path + ) + + if filter_amplitude_mode is not None: + assert filter_amplitude_mode in ["percentile", "absolute"] + + if filter_amplitude_mode == "percentile": + min_val, max_val = np.percentile( + spike_amplitudes, filter_amplitude_values + ) + else: + min_val, max_val = filter_amplitude_values + + if keep_bool_mask is None: + keep_bool_mask = np.ones(spike_amplitudes.size, dtype=bool) + + keep_bool_mask[spike_amplitudes < min_val] = False + keep_bool_mask[spike_amplitudes > max_val] = False + + if keep_bool_mask is not None: + spike_times = spike_times[keep_bool_mask] + spike_amplitudes = spike_amplitudes[keep_bool_mask] + spike_depths = spike_depths[keep_bool_mask] + spike_templates = spike_templates[keep_bool_mask] + + if decimate: + spike_times = spike_times[:: decimate] + spike_amplitudes = spike_amplitudes[:: decimate] + spike_depths = spike_depths[:: decimate] + spike_templates = spike_templates[:: decimate] + + return spike_times, spike_amplitudes, spike_depths, spike_templates + + def drift_map_plot_interactive( + self, + decimate=False, + exclude_noise=True, + amplitude_scaling="linear", + n_color_bins=20, + point_size=7.5, + filter_amplitude_mode=None, + filter_amplitude_values=(), + ): + """Create an interactive pyqtgraph-based drift map widget. 
+ + Parameters + ---------- + decimate : int | False + Keep every *n*-th spike. ``False`` disables decimation. + exclude_noise : bool + Remove spikes labelled as noise. + amplitude_scaling : {"linear", "log2", "log10"} | tuple + Colour-scaling mode. A 2-tuple ``(min, max)`` fixes the + colour range explicitly. + n_color_bins : int + Number of grey-scale colour bins for amplitude. + point_size : float + Scatter-point diameter in pixels. + filter_amplitude_mode : {"percentile", "absolute"} | None + Amplitude filtering mode (see ``_process_data``). + filter_amplitude_values : tuple of float + Bounds for amplitude filtering. + + Returns + ------- + DriftmapPlotWidget + The pyqtgraph widget (already populated but not yet shown). + """ + ( + spike_times, + spike_amplitudes, + spike_depths, + spike_templates, + ) = self._process_data( + exclude_noise, + decimate, + filter_amplitude_mode, + filter_amplitude_values + ) + + self.plot = DriftmapPlotWidget( + spike_times, + spike_amplitudes, + spike_depths, + spike_templates, + self.templates, + self.channel_positions, + amplitude_scaling=amplitude_scaling, + n_color_bins=n_color_bins, + point_size=point_size, + sorter_path=self.sorter_path + ) + + return self.plot + + # ---------------------------------------------------------------------------------- + # TODO MATPLOTLIB + # ---------------------------------------------------------------------------------- + + def _drift_map_plot_matplotlib(self, + decimate=False, + exclude_noise=True, + log_transform_amplitudes=True, + filter_amplitude_mode=None, + filter_amplitude_values=(), + ): + ( + spike_times, + spike_amplitudes, + spike_depths, + spike_templates, + ) = self._process_data( + exclude_noise, + decimate, + filter_amplitude_mode, + filter_amplitude_values + ) + + fig = plt.figure(figsize=(10, 10 * (6 / 8))) + raster_axis = fig.add_subplot() + + _plot_kilosort_drift_map_raster( + spike_times, + spike_amplitudes, + spike_depths, + axis=raster_axis, + ) + + # histogram + if 
False: + hist_axis.set_xlabel("count") + raster_axis.set_xlabel("time") + hist_axis.set_ylabel("y position") + + bin_centers, counts = _compute_activity_histogram( + spike_amplitudes, spike_depths, weight_histogram_by_amplitude + ) + hist_axis.plot(counts, bin_centers, color="black", linewidth=1) + + return fig diff --git a/interactive/multi_session_drift_map.py b/interactive/multi_session_drift_map.py new file mode 100644 index 0000000..295f700 --- /dev/null +++ b/interactive/multi_session_drift_map.py @@ -0,0 +1,120 @@ +import math +from PySide6 import QtWidgets +from .driftmap_plot_widget import DriftmapPlotWidget + + +class MultiSessionDriftmapWidget(QtWidgets.QWidget): + """A grid container that displays multiple :class:`DriftmapPlotWidget` panels. + + Panels are laid out on an auto-computed (or user-specified) grid and + their scatter-plot y-axes are linked so scrolling / zooming in one + panel keeps all panels in sync. + + Parameters + ---------- + panels : list[DriftmapPlotWidget] + Drift-map widgets to arrange in the grid. + grid : tuple[int, int] | None + Explicit ``(n_rows, n_cols)`` layout. If ``None``, a roughly + square layout is computed automatically. + width : int + Width allocated per panel column (pixels). + height : int + Height allocated per panel row (pixels). + """ + + def __init__( + self, + panels: list[DriftmapPlotWidget], + grid: tuple[int, int] | None = None, + width: int = 700, + height: int = 820, + ): + super().__init__() + self.setWindowTitle("Drift map — multi session") + + num_panels = len(panels) + n_rows, n_cols = self._compute_grid_dimensions(num_panels, grid) + + self.resize(width * n_cols, height * n_rows) + self._populate_grid(panels, n_rows, n_cols) + self._link_y_axes(panels) + + self.show() + + @staticmethod + def _compute_grid_dimensions( + num_panels: int, + grid: tuple[int, int] | None, + ) -> tuple[int, int]: + """Return ``(n_rows, n_cols)`` for the panel layout. 
+ + Parameters + ---------- + num_panels : int + Total number of panels to arrange. + grid : tuple[int, int] | None + User-specified ``(n_rows, n_cols)``. If ``None``, a roughly + square grid is computed automatically. + + Returns + ------- + n_rows, n_cols : int + """ + if grid is not None: + n_rows, n_cols = grid + if n_rows * n_cols != num_panels: + raise ValueError( + f"grid {grid} expects {n_rows * n_cols} panels " + f"but got {num_panels}" + ) + return n_rows, n_cols + + n_cols = math.ceil(math.sqrt(num_panels)) + n_rows = math.ceil(num_panels / n_cols) + + return n_rows, n_cols + + def _populate_grid( + self, + panels: list[DriftmapPlotWidget], + n_rows: int, + n_cols: int, + ) -> None: + """Place each panel into a :class:`QGridLayout`. + + Parameters + ---------- + panels : list[DriftmapPlotWidget] + Widgets to add. + n_rows, n_cols : int + Grid dimensions. + """ + grid_layout = QtWidgets.QGridLayout(self) + grid_layout.setContentsMargins(0, 0, 0, 0) + grid_layout.setSpacing(2) + + panel_idx = 0 + for row in range(n_rows): + for col in range(n_cols): + if panel_idx >= len(panels): + break + panels[panel_idx].setParent(self) + grid_layout.addWidget(panels[panel_idx], row, col) + panel_idx += 1 + + @staticmethod + def _link_y_axes(panels: list[DriftmapPlotWidget]) -> None: + """Link the scatter-plot y-axes of all panels to the first panel. + + This ensures scrolling or zooming the depth axis in any panel + updates all other panels to match. + + Parameters + ---------- + panels : list[DriftmapPlotWidget] + Must contain at least one panel. 
+ """ + ref = panels[0].p_scatter + for panel in panels[1:]: + panel.p_scatter.setYLink(ref) \ No newline at end of file diff --git a/kilosort_4.py b/kilosort_4.py deleted file mode 100644 index f526fd9..0000000 --- a/kilosort_4.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pathlib import Path - -import numpy as np -import helpers - - -def get_spikes_info_ks4( - sorter_output: Path, -): # -> tuple[np.ndarray, ...]: - - spike_times = np.load(sorter_output / "spike_times.npy") - spike_amplitudes = np.load(sorter_output / "amplitudes.npy") - spike_depths = np.load(sorter_output / "spike_positions.npy")[:, 1] - - return spike_times, spike_amplitudes, spike_depths diff --git a/ks_extractors/__init__.py b/ks_extractors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/helpers.py b/ks_extractors/helpers.py similarity index 51% rename from helpers.py rename to ks_extractors/helpers.py index 2b9ad2b..36cb8ee 100644 --- a/helpers.py +++ b/ks_extractors/helpers.py @@ -1,10 +1,11 @@ from __future__ import annotations from typing import TYPE_CHECKING -if TYPE_CHECKING: - from pathlib import Path +from pathlib import Path + import pandas as pd import numpy as np + def load_cluster_groups(cluster_path: Path) -> tuple[np.ndarray, ...]: """ Load kilosort `cluster_groups` file, that contains a table of @@ -47,32 +48,68 @@ def load_cluster_groups(cluster_path: Path) -> tuple[np.ndarray, ...]: return cluster_ids, cluster_groups -# This is such a jankily written function fix it -def exclude_noise(sorter_output, spike_times, spike_amplitudes, spike_depths): - """""" - if (cluster_path := sorter_output / "spike_clusters.npy").is_file(): # TODO: this can be csv?!?!? + +def get_noise_mask(sorter_output: Path) -> np.ndarray: + """Build a boolean mask identifying spikes that belong to noise clusters. 
+ + Loads ``spike_clusters.npy`` and the cluster-groups file + (``cluster_groups.csv`` or ``cluster_group.tsv``) from the sorter + output directory. Spikes whose cluster is labelled *noise* + (group == 0) are marked ``True``. + + Parameters + ---------- + sorter_output : Path + Path to the Kilosort sorter output directory. + + Returns + ------- + np.ndarray + (num_spikes,) boolean array — ``True`` for spikes belonging to + a noise cluster. + + Raises + ------ + FileNotFoundError + If ``spike_clusters.npy`` is not found. + ValueError + If neither ``cluster_groups.csv`` nor ``cluster_group.tsv`` + exists in ``sorter_output``. + """ + if (cluster_path := sorter_output / "spike_clusters.npy").is_file(): spike_clusters = np.load(cluster_path) else: - # this is a pain to have here, I don't think this case is realistic. - raise NotImplementedError("spike clusters.csv does not exist. Under what circumstance is this? probably very old.") - # spike_clusters = spike_templates.copy() + raise FileNotFoundError("spike_clusters.npy does not exist.") - if ( # short circuit ensures cluster_path is assigned appropriately + if not ( (cluster_path := sorter_output / "cluster_groups.csv").is_file() or (cluster_path := sorter_output / "cluster_group.tsv").is_file() ): - cluster_ids, cluster_groups = load_cluster_groups(cluster_path) + raise ValueError( + f"`exclude_noise` is `True` but there is no `cluster_groups.csv/.tsv` " + f"in the sorting output at: {sorter_output}" + ) + + cluster_ids, cluster_groups = load_cluster_groups(cluster_path) + + noise_cluster_ids = cluster_ids[cluster_groups == 0] - noise_cluster_ids = cluster_ids[cluster_groups == 0] - not_noise_clusters_by_spike = ~np.isin(spike_clusters.ravel(), - noise_cluster_ids) - spike_times = spike_times[not_noise_clusters_by_spike] - spike_amplitudes = spike_amplitudes[not_noise_clusters_by_spike] - spike_depths = spike_depths[not_noise_clusters_by_spike] + exclude_bool_mask = np.isin(spike_clusters.ravel(), noise_cluster_ids) 
- return spike_times, spike_amplitudes, spike_depths + return exclude_bool_mask - raise ValueError( - f"`exclude_noise` is `True` but there is no `cluster_groups.csv` or `.tsv` " - f"in the sorting output at: {sorter_output}" - ) + +def get_pooled_amplitudes(paths: list[Path]) -> np.ndarray: + """Load and concatenate amplitudes.npy from multiple sorter output paths. + + Parameters + ---------- + paths : list of Path + List of sorter output directories, each containing amplitudes.npy. + + Returns + ------- + np.ndarray + Concatenated amplitudes from all paths. + """ + return np.concatenate([np.load(Path(p) / "amplitudes.npy").ravel() for p in paths]) diff --git a/kilosort1_3.py b/ks_extractors/kilosort1_3.py similarity index 99% rename from kilosort1_3.py rename to ks_extractors/kilosort1_3.py index a4b8da9..e66e988 100644 --- a/kilosort1_3.py +++ b/ks_extractors/kilosort1_3.py @@ -67,7 +67,7 @@ def get_spikes_info_ks1_3( params["temp_scaling_amplitudes"], ) - return params["spike_times"], spike_amplitudes, spike_depths + return params["spike_times"], spike_amplitudes, spike_depths, params["spike_templates"].squeeze(), unwhite_templates, params["channel_positions"] def _template_positions_amplitudes( diff --git a/ks_extractors/kilosort_4.py b/ks_extractors/kilosort_4.py new file mode 100644 index 0000000..b7604b5 --- /dev/null +++ b/ks_extractors/kilosort_4.py @@ -0,0 +1,27 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +import numpy as np +from . 
import helpers + + +def get_spikes_info_ks4( + sorter_output: Path, +): # -> tuple[np.ndarray, ...]: + + # TODO: kept_spikes is not always the same size as other spike data + # Currently not used for loading + # kept_spikes = np.load(sorter_output / "kept_spikes.npy") + + spike_times = np.load(sorter_output / "spike_times.npy") + spike_amplitudes = np.load(sorter_output / "amplitudes.npy") + spike_depths = np.load(sorter_output / "spike_positions.npy")[:, 1] + spike_templates = np.load(sorter_output / "spike_templates.npy") # rename spike_templates_idx? + + templates = np.load(sorter_output / "templates.npy") # rename unwhitened? + channel_positions = np.load(sorter_output / "channel_positions.npy") + + return spike_times, spike_amplitudes, spike_depths, spike_templates, templates, channel_positions diff --git a/more_playing.py b/more_playing.py deleted file mode 100644 index f3cb1ce..0000000 --- a/more_playing.py +++ /dev/null @@ -1,162 +0,0 @@ -from driftmapviewer_new import get_drift_map_plot, _plot_kilosort_drift_map_raster, _filter_large_amplitude_spikes -import matplotlib.pyplot as plt -import kilosort1_3 -import kilosort_4 -import helpers -from pathlib import Path -import numpy as np - -# TODO: dont use gain, instead set clim -# TODO: removed localised peaks -# TODO: removed drift event and boundaries - -# TODO: it would be really cool and useful to hover over -# the plot and see the template waveform - -class DriftMapView(): - def __init__(self, sorter_path): - self.sorter_path = Path(sorter_path) - - log_file = list(self.sorter_path.glob("kilosort*.log")) - assert len(log_file) == 1 - self.ks_version = Path(log_file[0]).name.split(".")[0] - - # TODO: pay the cost once, then can plot a lot - # TOOD: compute cost of holding all in memory - - func = kilosort_4.get_spikes_info_ks4 if self.ks_version == "kilosort4" else kilosort1_3.get_spikes_info_ks1_3 - - self.spike_times, self.spike_amplitudes, self.spike_depths = func( - self.sorter_path - ) - 
self.spike_times.flags.writeable = False - self.spike_amplitudes.flags.writeable = False - self.spike_depths.flags.writeable = False - - def _process_data( - self, - exclude_noise, - log_transform_amplitudes, - decimate, - only_include_large_amplitude_spikes, - large_amplitude_only_segment_size - ): - # start with a view, but we may end up with a copy depending on the settings - spike_times = self.spike_times - spike_amplitudes = self.spike_amplitudes - spike_depths = self.spike_depths - - - # min, max = np.percentile(spike_amplitudes, (90, 98)) - # spike_amplitudes = np.clip(spike_amplitudes, min, max) - # This makes the assumption that there will never be different .csv and .tsv files - # in the same sorter output (this should never happen, there will never even be two). - # Though can be saved as .tsv, it seems the .csv is also tab formatted as far as pandas is concerned. - - - # TODO: this is super weird, can be improved? - if log_transform_amplitudes: - spike_amplitudes = np.log(spike_amplitudes) # TODO: give optional (None, 2 or 10) - - # Calculate the amplitude range for plotting first, so the scale is always the - # same across all options (e.g. decimation) which helps with interpretability. - amplitude_range_all_spikes = ( - spike_amplitudes.min(), - spike_amplitudes.max(), - ) - - # TODO: move exclude noise here! 
- if exclude_noise: - spike_times, spike_amplitudes, spike_depths = helpers.exclude_noise( - self.sorter_path, spike_times, spike_amplitudes, spike_depths - ) - - if decimate: - spike_times = spike_times[:: decimate] - spike_amplitudes = spike_amplitudes[:: decimate] - spike_depths = spike_depths[:: decimate] - - if only_include_large_amplitude_spikes: - - spike_times, spike_amplitudes, spike_depths = _filter_large_amplitude_spikes( - spike_times, spike_amplitudes, spike_depths, - large_amplitude_only_segment_size - ) - - return spike_times, spike_amplitudes, spike_depths, amplitude_range_all_spikes - - def get_drift_map_plot(self, - only_include_large_amplitude_spikes=True, - decimate=False, - exclude_noise=True, - log_transform_amplitudes=True, - large_amplitude_only_segment_size=800, - ): - ( - spike_times, - spike_amplitudes, - spike_depths, - amplitude_range_all_spikes - ) = self._process_data( - exclude_noise, - log_transform_amplitudes, - decimate, - only_include_large_amplitude_spikes, - large_amplitude_only_segment_size - ) - - fig = plt.figure(figsize=(10, 10 * (6 / 8))) - raster_axis = fig.add_subplot() - - _plot_kilosort_drift_map_raster( - spike_times, - spike_amplitudes, - spike_depths, - amplitude_range_all_spikes, - axis=raster_axis, - ) - - return fig - - def get_1d_histogram_plot( - self, - ): - pass - - def get_2d_histogram_plot(self): - pass - - def get_combined_drift_map_plot( - self - ): - pass - -for file in [ - r"C:\Users\Jzimi\Desktop\derivatives\1119617_LSE1_shank12_g0\0\sorter_output", -]: - plotter = DriftMapView( - file - ) - - fig = plotter.get_drift_map_plot( - only_include_large_amplitude_spikes=True, # exclude amplitude outliers? maybe just do this instead of doing this segmented way. 
- decimate=False, - exclude_noise=False, - log_transform_amplitudes=True - ) - -plt.show() - - -if False: - plot = get_drift_map_plot( - r"C:\Users\Joe\PycharmProjects\viewephys3\kilosort4_output\sorter_output", - only_include_large_amplitude_spikes=True, - add_histogram_plot=True, - weight_histogram_by_amplitude=True, - decimate=False, - exclude_noise=True - ) - - plt.show() diff --git a/mpl_plotting/__init__.py b/mpl_plotting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/driftmapviewer_new.py b/mpl_plotting/driftmapviewer_new.py similarity index 84% rename from driftmapviewer_new.py rename to mpl_plotting/driftmapviewer_new.py index 1e270c9..025a585 100644 --- a/driftmapviewer_new.py +++ b/mpl_plotting/driftmapviewer_new.py @@ -9,6 +9,7 @@ import matplotlib.pyplot as plt from scipy import stats +from ks_extractors import kilosort_4, kilosort1_3 def get_spikes_info(sorter_output, ks_version, exclude_noise): @@ -81,7 +82,6 @@ def get_drift_map_plot( else: raster_axis = fig.add_subplot() - print(amplitude_range_all_spikes) _plot_kilosort_drift_map_raster( spike_times, spike_amplitudes, @@ -108,44 +108,6 @@ def get_drift_map_plot( return fig -# TODO: not sure about this, much point in computing as segments? Just combine? Its just another parameter to track... -# hmm I guess it makes sense if going through brain regions, and if want to do the entire thing, you can set segment to probe... -def _filter_large_amplitude_spikes( - spike_times: np.ndarray, - spike_amplitudes: np.ndarray, - spike_depths: np.ndarray, - large_amplitude_only_segment_size, -) -> tuple[np.ndarray, ...]: - """ - Return spike properties with only the largest-amplitude spikes included. The probe - is split into segments, and within each segment the mean and std computed. - Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded - Splitting the probe is only done for the exclusion step, the returned array are flat. 
- - Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns - copies of these arrays containing only the large amplitude spikes. - """ - spike_bool = np.zeros_like(spike_amplitudes, dtype=bool) - - segment_size_um = large_amplitude_only_segment_size - probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um - - for segment_left_edge in probe_segments_left_edges: - segment_right_edge = segment_left_edge + segment_size_um - - spikes_in_seg = np.where( - np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge) - )[0] - spike_amps_in_seg = spike_amplitudes[spikes_in_seg] - is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1) - - spike_bool[spikes_in_seg] = is_high_amplitude - - spike_times = spike_times[spike_bool] - spike_amplitudes = spike_amplitudes[spike_bool] - spike_depths = spike_depths[spike_bool] - - return spike_times, spike_amplitudes, spike_depths def _plot_kilosort_drift_map_raster( spike_times: np.ndarray,