diff --git a/.gitignore b/.gitignore index b7faf40..09be0ef 100644 --- a/.gitignore +++ b/.gitignore @@ -182,9 +182,9 @@ cython_debug/ .abstra/ # Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, +# and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the entire vscode folder # .vscode/ diff --git a/MANIFEST.in b/MANIFEST.in index e16ea33..25dc798 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -6,3 +6,4 @@ recursive-exclude * __pycache__ recursive-exclude * *.py[co] recursive-exclude docs * recursive-exclude tests * +recursive-exclude examples * diff --git a/docs/requirements.txt b/docs/requirements.txt index 1b2fdb9..605d058 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,4 +6,5 @@ pydata-sphinx-theme setuptools-scm sphinx sphinx-autodoc-typehints +sphinx-design sphinx-sitemap diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst deleted file mode 100644 index f7b449e..0000000 --- a/docs/source/api_index.rst +++ /dev/null @@ -1,25 +0,0 @@ -API -=== - -math ----- - -.. currentmodule:: driftmap_viewer.math - -.. autosummary:: - :toctree: api_generated - :template: function.rst - - add_two_integers - subtract_two_integers - -greetings ---------- - -.. currentmodule:: driftmap_viewer.greetings - -.. 
autosummary:: - :toctree: api_generated - :template: class.rst - - Greetings diff --git a/docs/source/conf.py b/docs/source/conf.py index 271c34d..485f1a2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -46,6 +46,7 @@ "sphinx_sitemap", "myst_parser", "nbsphinx", + "sphinx_design", ] # Configure the myst parser to enable cool markdown features @@ -64,6 +65,8 @@ "strikethrough", "substitution", "tasklist", + "attrs_block", + "attrs_inline", ] # Automatically add anchors to markdown headings myst_heading_anchors = 3 @@ -122,4 +125,10 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ['_static'] +html_static_path = ['_static'] + +html_css_files = [ + 'css/custom.css', +] + +html_show_sourcelink = False diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md deleted file mode 100644 index 855d15f..0000000 --- a/docs/source/getting_started.md +++ /dev/null @@ -1,3 +0,0 @@ -# Getting started - -Documentation placeholder diff --git a/docs/source/index.md b/docs/source/index.md new file mode 100644 index 0000000..0c81649 --- /dev/null +++ b/docs/source/index.md @@ -0,0 +1,60 @@ +:html_theme.sidebar_secondary.remove: + +```{raw} html +
+``` +# driftmap_viewer +```{raw} html +
+``` + +Driftmap Viewer (TODO: NAME) is a tool to visualise and save drift maps from kilosort +or SpikeInterface and interactively compare these across sessions. + +The main use case of Driftmap Viewer is to check the alignment of sessions prior to +performing inter-session Unit Matching or data concatenation. In the below example, +(XXX) it is clear the two sessions are well aligned, with drift maps looking similar and +templates clearly matching at the same position of the probe. + +Driftmap viewer can be used in this interactive view () or through matplotlib +plots that can be used to save PDF over an entire experiment for quick checking + +::::{grid} 1 1 3 3 +:gutter: 4 + +:::{grid-item-card} {fas}`book;sd-text-primary` Articles +:link: pages/articles/index +:link-type: doc + +Guides for the interactive and matplotlib viewers, and how drift-map parameters are calculated. +::: + +:::{grid-item-card} {fas}`lightbulb;sd-text-primary` Examples +:link: pages/examples/index +:link-type: doc + +Worked examples using ``driftmap_viewer`` in practice. +::: + +:::{grid-item-card} {fas}`code;sd-text-primary` API +:link: pages/api_index +:link-type: doc + +Full Python API reference. +::: + +:::: + +``driftmap_viewer`` loads Kilosort sorter output and creates +interactive (pyqtgraph) or static (matplotlib) drift map plots, +making it easy to inspect electrode drift across a recording session. + +```{toctree} +:maxdepth: 2 +:caption: index +:hidden: + +pages/articles/index +pages/examples/index +pages/api_index +``` diff --git a/docs/source/index.rst b/docs/source/index.rst deleted file mode 100644 index d895be9..0000000 --- a/docs/source/index.rst +++ /dev/null @@ -1,46 +0,0 @@ -.. driftmap_viewer documentation master file, created by - sphinx-quickstart on Fri Dec 9 14:12:42 2022. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -Welcome to driftmap_viewer's documentation! 
-========================================================= - -.. toctree:: - :maxdepth: 2 - :caption: Contents: - - getting_started - api_index - -By default the documentation includes the following sections: - -* Getting started. Here you could describe the basic functionalities of your package. To modify this page, edit the file ``docs/source/getting_started.md``. -* API: here you can find the auto-generated documentation of your package, which is based on the docstrings in your code. To modify which modules/classes/functions are included in the API documentation, edit the file ``docs/source/api_index.rst``. - -You can create additional sections with narrative documentation, -by adding new ``.md`` or ``.rst`` files to the ``docs/source`` folder. -These files should start with a level-1 (H1) header, -which will be used as the section title. Sub-sections can be created -with lower-level headers (H2, H3, etc.) within the same file. - -To include a section in the rendered documentation, -add it to the ``toctree`` directive in this (``docs/source/index.rst``) file. - -For example, you could create a ``docs/source/installation.md`` file -and add it to the ``toctree`` like this: - -.. code-block:: rst - - .. toctree:: - :maxdepth: 2 - :caption: Contents: - - getting_started - installation - api_index - -Index & Search --------------- -* :ref:`genindex` -* :ref:`search` diff --git a/docs/source/pages/api_index.rst b/docs/source/pages/api_index.rst new file mode 100644 index 0000000..0dd896f --- /dev/null +++ b/docs/source/pages/api_index.rst @@ -0,0 +1,10 @@ +.. _API_Reference: + +API Reference +============= + +.. currentmodule:: driftmap_viewer.interactive.driftmap_view + +.. 
autoclass:: DriftMapView + :members: + :undoc-members: diff --git a/docs/source/pages/articles/how-parameters-are-calculated.md b/docs/source/pages/articles/how-parameters-are-calculated.md new file mode 100644 index 0000000..44929e7 --- /dev/null +++ b/docs/source/pages/articles/how-parameters-are-calculated.md @@ -0,0 +1,22 @@ +# How parameters are calculated + +Drfitmap Viewer supports kilosort outputs or spikeinterface outputs directly. Unfortunately, +the term 'amplitude' does not always map to the exact spike times. For kilosort, the term 'amplitude' +are not true amplitudes in uV but XXX, and will change depending on the version. + +# Amplitudes + +KS1-3 +KS4 + +# Why are whitened templates show for KS output? + +For spikeinterface, the template is the . For KS, it is the XXX. It is possible +to unwhiten the templates. However, for KS2.5 and Ks3, this matrix is not save dproperly. +Therefore, for consistent whitetneied templates are always used. If you would like unwhitented +templates, please get in touch. + + +NOTE THAT SPIKE_CLUSTERS IS ATTEMPTED TO BE USED +NOTE THAT AMPLITUDES ARE ALWAYS POSITIVE +TODO: these amplitudes are not scaled by gain / offset, but this doesn't matter for our purposes diff --git a/docs/source/pages/articles/how-to-use-driftmapviewer.md b/docs/source/pages/articles/how-to-use-driftmapviewer.md new file mode 100644 index 0000000..d311002 --- /dev/null +++ b/docs/source/pages/articles/how-to-use-driftmapviewer.md @@ -0,0 +1,38 @@ +# Using the interactive viewer + +Driftmap viewer is a lightweight tool. We can create the viewer instance +by supplying a path to a sorting output, or a SpikeInterface sortinganalyzer object. + +For example: + +```python + +``` + +it just shows the template for that spike, it does NOT show the spike or even the scaled template, +due to inconsistencies in XXX. + +This will load all key features (e.g. spike imes, ampliudes) into memory, once. + +Next, we can plot using driftmap interactive or matplotlib. 
These take all of the same arguments, +except for matplotlib which also can show a 1D activity history. + +As part of XXX, the data is processed. + +WARNINGS: designed for NP1 probes, will work for Cam Neurotech or NeuroNexus but not well tested. Please get in touch. + +# Using the interactive Viewer + +Note for SI: +max_spikes_per_unit=1_000_000, # This determines the number of spikes that will appear on the SI drift plot + +# for KS, +# TODO: these amplitudes are not scaled by gain / offset, but this doesn't matter for our purposes + +See how amplitudes are calculated +Why unwhitened templates are displayed (except spikeinterface). + +# Using matplotlib + + +# See the API documentation for each thing diff --git a/docs/source/pages/articles/index.md b/docs/source/pages/articles/index.md new file mode 100644 index 0000000..5310122 --- /dev/null +++ b/docs/source/pages/articles/index.md @@ -0,0 +1,40 @@ +:html_theme.sidebar_secondary.remove: + +(articles)= +# Articles + +::::{grid} 1 1 2 2 +:gutter: 4 + +:::{grid-item-card} {fas}`desktop;sd-text-primary` Using the interactive viewer +:link: interactive-viewer +:link-type: doc + +Launch and navigate the pyqtgraph-based interactive drift map. +::: + +:::{grid-item-card} {fas}`chart-line;sd-text-primary` Using the matplotlib viewer +:link: matplotlib-viewer +:link-type: doc + +Generate static drift map figures with matplotlib. +::: + +:::{grid-item-card} {fas}`calculator;sd-text-primary` How parameters are calculated +:link: how-parameters-are-calculated +:link-type: doc + +Details on how spike depths, amplitudes and colour scaling are derived. 
+::: + +:::: + +```{toctree} +:maxdepth: 2 +:caption: Articles +:hidden: + +interactive-viewer +matplotlib-viewer +how-parameters-are-calculated +``` diff --git a/docs/source/pages/examples/creating-pdf.md b/docs/source/pages/examples/creating-pdf.md new file mode 100644 index 0000000..da892be --- /dev/null +++ b/docs/source/pages/examples/creating-pdf.md @@ -0,0 +1,3 @@ +# Creating a PDF from an experiment + +*Coming soon.* diff --git a/docs/source/pages/examples/index.md b/docs/source/pages/examples/index.md new file mode 100644 index 0000000..59f8e7d --- /dev/null +++ b/docs/source/pages/examples/index.md @@ -0,0 +1,32 @@ +:html_theme.sidebar_secondary.remove: + +(examples)= +# Examples + +::::{grid} 1 1 2 2 +:gutter: 4 + +:::{grid-item-card} {fas}`columns;sd-text-primary` Multi-widgets for comparing sessions +:link: multi-widget-comparison +:link-type: doc + +Use multiple drift map widgets side-by-side to compare sessions. +::: + +:::{grid-item-card} {fas}`file-pdf;sd-text-primary` Creating a PDF from an experiment +:link: creating-pdf +:link-type: doc + +Export drift map figures to a PDF report for an experiment. 
+::: + +:::: + +```{toctree} +:maxdepth: 2 +:caption: Examples +:hidden: + +multi-widget-comparison +creating-pdf +``` diff --git a/docs/source/pages/examples/multi-widget-comparison.md b/docs/source/pages/examples/multi-widget-comparison.md new file mode 100644 index 0000000..69ab150 --- /dev/null +++ b/docs/source/pages/examples/multi-widget-comparison.md @@ -0,0 +1,3 @@ +# Multi-widgets for comparing sessions + +*Coming soon.* diff --git a/driftmap_viewer/__init__.py b/driftmap_viewer/__init__.py new file mode 100644 index 0000000..76b2298 --- /dev/null +++ b/driftmap_viewer/__init__.py @@ -0,0 +1,3 @@ +from driftmap_viewer.driftmap_view import DriftMapView +from driftmap_viewer.interactive.multi_session_drift_map import MultiSessionDriftmapWidget +from driftmap_viewer.amplitudes import get_amplitudes diff --git a/driftmap_viewer/amplitudes.py b/driftmap_viewer/amplitudes.py new file mode 100644 index 0000000..470261d --- /dev/null +++ b/driftmap_viewer/amplitudes.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import spikeinterface as si + +from driftmap_viewer.data_loader import DataLoader + + +def get_amplitudes( + list_of_path_or_analyzer: list[Path | si.SortingAnalyzer], + exclude_noise: bool = False, + concatenate: bool = False, +) -> np.ndarray | list[np.ndarray]: + """Load and concatenate amplitudes.npy from multiple sorter output paths. + + Parameters + ---------- + list_of_path_or_analyzer + List of sorter output directories, each containing amplitudes.npy. + concatenate + If ``True``, concatenate all amplitudes into a single array. + + Returns + ------- + np.ndarray or list of np.ndarray + Concatenated amplitudes from all paths. 
+ """ + all_spike_amplitudes = [] + + for path_or_analyzer in list_of_path_or_analyzer: + loader = DataLoader(path_or_analyzer) + + processed_data = loader.get_processed_data( + exclude_noise, + decimate=False, + filter_amplitude_mode=None, + filter_amplitude_values=None, + ) + + all_spike_amplitudes.append(processed_data.spike_amplitudes) + + if concatenate: + all_spike_amplitudes = np.concatenate(all_spike_amplitudes) + + return all_spike_amplitudes diff --git a/driftmap_viewer/data_loader.py b/driftmap_viewer/data_loader.py new file mode 100644 index 0000000..0271f65 --- /dev/null +++ b/driftmap_viewer/data_loader.py @@ -0,0 +1,150 @@ +from pathlib import Path +from typing import Callable + +import numpy as np +import spikeinterface as si + +from driftmap_viewer.data_model import DataModel +from driftmap_viewer.extractors import ( + analyzer_helpers, + kilosort1_3, + kilosort4, + kilosort_helpers, +) + + +class DataLoader: + """""" + + def __init__(self, path_or_analyzer: Path | si.SortingAnalyzer) -> None: + """ """ + self.path_or_analyzer = path_or_analyzer + + # Get the data loading function depending on if + # we are analyzer or kilosort output + func: Callable + if isinstance(path_or_analyzer, si.SortingAnalyzer): + func = analyzer_helpers.get_sorting_analyzer + else: + ks_version = kilosort_helpers.get_ks_version(Path(path_or_analyzer)) + func = ( + kilosort4.get_spikes_info_ks4 + if ks_version == "kilosort4" + else kilosort1_3.get_spikes_info_ks1_3 + ) + + # Load the required data and check sizes match (one entry per spike) + ( + self._spike_times, + self._spike_amplitudes, + self._spike_depths, + self._spike_clusters, + self.templates, + self.channel_locations, + ) = func(path_or_analyzer) + + assert ( + self._spike_times.size + == self._spike_amplitudes.size + == self._spike_depths.size + == self._spike_clusters.size + ) + assert self.channel_locations.shape[0] > self.channel_locations.shape[1] + + self._spike_times.flags.writeable = False + 
self._spike_amplitudes.flags.writeable = False + self._spike_depths.flags.writeable = False + self._spike_clusters.flags.writeable = False + self.templates.flags.writeable = False + self.channel_locations.flags.writeable = False + + def get_processed_data( + self, exclude_noise, decimate, filter_amplitude_mode, filter_amplitude_values + ): + """Filter and subsample the loaded spike data. + + Operations are applied in order: decimation → noise exclusion → + amplitude filtering → masking. Decimation is applied first as a + performance knob to thin the full dataset before further filtering. + + Parameters + ---------- + exclude_noise : + If ``True``, spikes belonging to clusters labelled "noise" in + the Kilosort cluster groups file are removed. + decimate : int | False + Keep every *n*-th spike. Applied first to reduce the dataset + before noise/amplitude filters. ``False`` disables decimation. + filter_amplitude_mode : {"percentile", "absolute"} | None + How ``filter_amplitude_values`` is interpreted. + ``None`` disables amplitude filtering. + filter_amplitude_values : tuple of float + (low, high) bounds. Interpreted as percentile ranks or + absolute amplitude values depending on ``filter_amplitude_mode``. + + Returns + ------- + spike_times : np.ndarray + spike_amplitudes : np.ndarray + spike_depths : np.ndarray + spike_clusters : np.ndarray + Filtered copies (views when no filtering is needed) of the + corresponding instance arrays. + """ + # Select a view for now, this may be copied depending on options (e.g. 
decimate) + spike_times = self._spike_times + spike_amplitudes = self._spike_amplitudes + spike_depths = self._spike_depths + spike_clusters = self._spike_clusters + + keep_bool_mask = None + + # First, exclude spikes from units labeled as "noise" + if exclude_noise: + if isinstance(self.path_or_analyzer, si.SortingAnalyzer): + keep_bool_mask = ~analyzer_helpers.get_noise_mask( + exclude_noise, spike_clusters, self.path_or_analyzer + ) + else: + keep_bool_mask = ~kilosort_helpers.get_noise_mask( + spike_clusters, self.path_or_analyzer + ) + + # Next, filter spikes based on amplitude + if filter_amplitude_mode is not None: + assert filter_amplitude_mode in ["percentile", "absolute"] + + if filter_amplitude_mode == "percentile": + min_val, max_val = np.percentile( + spike_amplitudes, filter_amplitude_values + ) + else: + min_val, max_val = filter_amplitude_values + + if keep_bool_mask is None: + keep_bool_mask = np.ones(spike_amplitudes.size, dtype=bool) + + keep_bool_mask[spike_amplitudes < min_val] = False + keep_bool_mask[spike_amplitudes > max_val] = False + + # mask exclude_noise / filtered amplitudes + if keep_bool_mask is not None: + spike_times = spike_times[keep_bool_mask] + spike_amplitudes = spike_amplitudes[keep_bool_mask] + spike_depths = spike_depths[keep_bool_mask] + spike_clusters = spike_clusters[keep_bool_mask] + + if decimate: + spike_times = spike_times[::decimate] + spike_amplitudes = spike_amplitudes[::decimate] + spike_depths = spike_depths[::decimate] + spike_clusters = spike_clusters[::decimate] + + return DataModel( + spike_times, + spike_amplitudes, + spike_depths, + spike_clusters, + self.templates, + self.channel_locations, + ) diff --git a/driftmap_viewer/data_model.py b/driftmap_viewer/data_model.py new file mode 100644 index 0000000..ba9ca67 --- /dev/null +++ b/driftmap_viewer/data_model.py @@ -0,0 +1,191 @@ +import warnings + +import matplotlib.pyplot as plt +import numpy as np + + +class DataModel: + def __init__( + self, + 
spike_times, + spike_amplitudes, + spike_depths, + spike_clusters, + templates, + channel_locations, + ): + self.spike_times = spike_times + self.spike_depths = spike_depths + self.spike_amplitudes = spike_amplitudes + self.spike_clusters = spike_clusters + self.templates = templates + self.channel_locations = channel_locations + + def get_scatter_data(self): + return self.spike_times, self.spike_depths, self.spike_amplitudes + + def get_template_id(self, spike_idx): + return self.spike_clusters[spike_idx] + + def get_template_heatmap(self, spike_index, view_mode): + """ """ + # Extract the template for this spike + template_idx = self.spike_clusters[spike_index] + template = self.templates[template_idx, :, :] + mid_idx = int(template.shape[0] / 2) + + # Next we need to find the shank the template is on. For KS, + # signal can also be found on other shanks but this confuses the + # visualisation + # Find the channel with maximum signal + max_chan_idx = np.argmax(np.max(np.abs(template), axis=0)) + max_signal_x_loc = self.channel_locations[max_chan_idx, 0] + + # Find other channels in the shank column. Because we are working on + # KS outputs, we have no knowledge of the probe, so we have to guess. + # Based on the column vs. shank space for a number of popular probes + # (NP1: 1 shank, 70um across, NP2: 250um between shank, shank width ~70um, + # Cambridge Neurotech: shank widths ~80 µm, shank spacing ~200um+, + # NeuroNexus: does have some shank widths at 100-120um)), in which + # this will fail. The simplest solution is to document and + # down the line expose this parameter. + COL_CUTOFF_UM = 125 + + chan_x_locs = np.unique(self.channel_locations[:, 0]) + + chan_x_spacings = np.diff(chan_x_locs) + if np.any( + np.logical_and(chan_x_spacings > COL_CUTOFF_UM, chan_x_spacings < 150) + ): + warnings.warn( + f"The spacings between x-locations: {chan_x_spacings} makes " + f"it difficult to distinguish between channel and shank spacing. 
" + f"The cutoff is {COL_CUTOFF_UM}, less than this is assumed to be " + f"two columns of channels on the same shank." + ) + + valid_pos = chan_x_locs[np.abs(chan_x_locs - max_signal_x_loc) < COL_CUTOFF_UM] + + shank_select = np.zeros(self.channel_locations.shape[0], dtype=bool) + for pos in valid_pos: + shank_select = np.logical_or( + shank_select, self.channel_locations[:, 0] == pos + ) + + # Often the contact positions are not organised contiguous + # along the y-dimension and need resorting + sort_idx = np.argsort(self.channel_locations[shank_select, 1], axis=0) + + # Select the shank of interest ordered by depth + template = template[:, shank_select] + template = template[:, sort_idx] + + # Either display only the channels with signal on, or all channels but + # non-signal channels are empty. Using the threshold ==0 works well for + # SI analyzer and whitened KS templates, less well for un-whitened KS + # templates which have nonzero signal on all channel, but for which no + # clear threshold exists. + if view_mode == "heatmap_all_channels": + template = template.copy() + template[:, template[mid_idx, :] == 0] = np.nan + else: + contains_data_idx = np.where(template[mid_idx, :] != 0)[0] + template = template[:, contains_data_idx] + + return template + + # TODO: CHECK THIS + def compute_amplitude_colors( + self, amplitude_scaling, n_color_bins, unit_normalise=False + ): + """Map spike amplitudes to RGBA colours via grey-scale binning. + + Parameters + ---------- + amplitude_scaling : {"linear", "log2", "log10"} | tuple + Scaling mode. A 2-tuple ``(min, max)`` fixes the colour + range explicitly. + n_color_bins : int + Number of grey-scale bins. 
+ + Returns + ------- + np.ndarray + (num_spikes, 4) uint8 RGBA values + """ + amp_values = np.abs(self.spike_amplitudes) + + if isinstance(amplitude_scaling, tuple): + amp_min, amp_max = amplitude_scaling + else: + if amplitude_scaling == "log2": + amp_values = np.log2(np.maximum(amp_values, np.finfo(float).eps)) + + elif amplitude_scaling == "log10": + amp_values = np.log10(np.maximum(amp_values, np.finfo(float).eps)) + + amp_min, amp_max = amp_values.min(), amp_values.max() + + color_bins = np.linspace(amp_min, amp_max, n_color_bins) + gray_colors = plt.get_cmap("gray")(np.linspace(0, 1, n_color_bins))[::-1] + bin_indices = np.clip( + np.searchsorted(color_bins, amp_values, side="right") - 1, + 0, + n_color_bins - 2, + ) + + colors = gray_colors[bin_indices] + + if not unit_normalise: + colors *= 255 + colors = colors.astype(np.uint8) + + return colors + + def compute_activity_histogram( + self, weight_histogram_by_amplitude: bool + ) -> tuple[np.ndarray, ...]: + """ + Compute the activity histogram for the kilosort drift map's left-side plot. + + Parameters + ---------- + weight_histogram_by_amplitude : bool + If `True`, the spike amplitudes are taken into + consideration when generating the histogram. + The amplitudes are scaled to the range [0, 1] + then summed for each bin, to generate the + histogram values. If `False`, counts + (i.e. num spikes per bin) are used. + + Returns + ------- + bin_centers : np.ndarray + The spatial bin centers (probe depth) for the histogram. + values : np.ndarray + The histogram values. If + `weight_histogram_by_amplitude` is `False`, + these values represent are counts, otherwise + they are counts weighted by amplitude. + """ + # `spike amplitudes should be high precision as many values are summed. 
+ spike_amplitudes = self.spike_amplitudes.astype(np.float64) + + bin_um = 2 + bins = np.arange( + self.spike_depths.min() - bin_um, self.spike_depths.max() + bin_um, bin_um + ) + values, bins = np.histogram(self.spike_depths, bins=bins) + bin_centers = (bins[:-1] + bins[1:]) / 2 + + if weight_histogram_by_amplitude: + bin_indices = np.digitize(self.spike_depths, bins, right=True) - 1 + values = np.zeros(bin_indices.max() + 1, dtype=np.float64) + + scaled_spike_amplitudes = ( + spike_amplitudes - spike_amplitudes.min() + ) / np.ptp(spike_amplitudes) + + np.add.at(values, bin_indices, scaled_spike_amplitudes) + + return bin_centers, values diff --git a/driftmap_viewer/driftmap_view.py b/driftmap_viewer/driftmap_view.py new file mode 100644 index 0000000..5844aaa --- /dev/null +++ b/driftmap_viewer/driftmap_view.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from pathlib import Path + +from matplotlib.figure import Figure + +from driftmap_viewer import mpl_plotting +from driftmap_viewer.data_loader import DataLoader +from driftmap_viewer.interactive.driftmap_plot_widget import DriftmapPlotWidget + +# test ideas: +# check signatures match default args between itneractive and matplotlib + + +class DriftMapView: + """Load Kilosort sorter output and provide interactive or static drift map plots. + + On construction, spike data is loaded from a Kilosort output directory + and stored as read-only arrays. Plotting methods apply optional + filtering (noise exclusion, amplitude filtering, decimation) before + handing the data to a plot backend. + + Parameters + ---------- + sorter_path + Path to a Kilosort sorter output directory. Must contain + exactly one ``kilosort*.log`` file used to detect the KS version. + + Attributes + ---------- + spike_times + (num_spikes,) spike times (seconds for KS 1-3, samples for KS4). + spike_amplitudes + (num_spikes,) spike amplitudes. + spike_depths + (num_spikes,) spike depths along the probe (µm). 
+ spike_clusters + (num_spikes,) template index assigned to each spike. + templates + (num_templates, num_samples, num_channels) template waveforms. + channel_locations + (num_channels, 2) x/y positions of each channel on the probe. + """ + + def __init__(self, sorter_path: str | Path) -> None: + """Load spike data from a Kilosort output directory. + + Parameters + ---------- + sorter_path + Path to the Kilosort sorter output. + + Raises + ------ + AssertionError + If the directory does not contain exactly one ``kilosort*.log`` + file, or if the loaded spike arrays have mismatched sizes. + """ + self.data_loader = DataLoader(sorter_path) # TODO: rename + + def drift_map_plot_interactive( + self, + decimate: int | bool = False, + exclude_noise: bool = False, + amplitude_scaling: str | tuple[float, float] = "linear", + n_color_bins: int = 20, + point_size: float = 7.5, + filter_amplitude_mode: str | None = None, + filter_amplitude_values: tuple[float, ...] = (), + ) -> DriftmapPlotWidget: + """Create an interactive pyqtgraph-based drift map widget. + + Parameters + ---------- + decimate + Keep every *n*-th spike. ``False`` disables decimation. + exclude_noise + Remove spikes labelled as noise. + amplitude_scaling + Colour-scaling mode or explicit ``(min, max)`` range. + n_color_bins + Number of grey-scale colour bins for amplitude. + point_size + Scatter-point diameter in pixels. + filter_amplitude_mode + Amplitude filtering mode. + filter_amplitude_values + Bounds for amplitude filtering. + + Returns + ------- + DriftmapPlotWidget + The pyqtgraph widget. This is already populated but not yet + shown, use app.exec() to display. 
+ """ + processed_data = self.data_loader.get_processed_data( + exclude_noise, decimate, filter_amplitude_mode, filter_amplitude_values + ) + + self.plot = DriftmapPlotWidget( + processed_data, + amplitude_scaling=amplitude_scaling, + n_color_bins=n_color_bins, + point_size=point_size, + ) + + return self.plot + + def drift_map_plot_matplotlib( + self, + decimate: int | bool = False, + exclude_noise: bool | str = False, + amplitude_scaling: str | tuple[float, float] = "linear", + n_color_bins: int = 20, + point_size: float = 7.5, + filter_amplitude_mode: str | None = None, + filter_amplitude_values: tuple[float, ...] = (), + add_histogram_plot: bool = False, + weight_histogram_by_amplitude: bool = False, + ) -> Figure: + """""" + processed_data = self.data_loader.get_processed_data( + exclude_noise, decimate, filter_amplitude_mode, filter_amplitude_values + ) + + fig = mpl_plotting.plot_matplotlib( + processed_data, + amplitude_scaling, + n_color_bins, + point_size, + add_histogram_plot, + weight_histogram_by_amplitude, + ) + + return fig diff --git a/driftmap_viewer/driftmapviewer.py b/driftmap_viewer/driftmapviewer.py deleted file mode 100644 index a2c5a6f..0000000 --- a/driftmap_viewer/driftmapviewer.py +++ /dev/null @@ -1,521 +0,0 @@ -from pathlib import Path -from .base import BaseWidget, to_attr -import matplotlib.axis -import scipy.signal -from spikeinterface.core import read_python -import numpy as np -import pandas as pd - -import matplotlib.pyplot as plt -from scipy import stats - - -class KilosortDriftMapWidget(BaseWidget): - """ - Create a drift map plot in the kilosort style. This is ported from Nick Steinmetz's - `spikes` repository MATLAB code, https://github.com/cortex-lab/spikes. - - By default, a raster plot is drawn with the y-axis is spike depth and - x-axis is time. Optionally, a corresponding 2D activity histogram can be - added as a subplot (spatial bins, spike counts) with optional - peak coloring and drift event detection (see below). 
- - Parameters - ---------- - sorter_output : str | Path, - Path to the kilosort output folder. - only_include_large_amplitude_spikes : bool - If `True`, only spikes with larger amplitudes are included. For - details, see `_filter_large_amplitude_spikes()`. - decimate : None | int - If an integer n, only every nth spike is kept from the plot. Useful for improving - performance when there are many spikes. If `None`, spikes will not be decimated. - add_histogram_plot : bool - If `True`, an activity histogram will be added to a new subplot to the - left of the drift map. - add_histogram_peaks_and_boundaries : bool - If `True`, activity histogram peaks are detected and colored red if - isolated according to start/end boundaries of the peak (blue otherwise). - add_drift_events : bool - If `True`, drift events will be plot on the raster map. Required - `add_histogram_plot` and `add_histogram_peaks_and_boundaries` to run. - weight_histogram_by_amplitude : bool - If `True`, histogram counts will be weighted by spike amplitude. - localised_spikes_only : bool - If `True`, only spatially isolated spikes will be included. - exclude_noise : bool - If `True`, units labelled as noise in the `cluster_groups` file - will be excluded. - gain : float | None - If not `None`, amplitudes will be scaled by the supplied gain. - large_amplitude_only_segment_size: float - If `only_include_large_amplitude_spikes` is `True`, the probe is split into - segments to compute mean and std used as threshold. This sets the size of the - segments in um. - localised_spikes_channel_cutoff: int - If `localised_spikes_only` is `True`, spikes that have more than half of the - maximum loading channel over a range of > n channels are removed. - This sets the number of channels. 
- """ - - def __init__( - self, - sorter_output: str | Path, - only_include_large_amplitude_spikes: bool = True, - decimate: None | int = None, - add_histogram_plot: bool = False, - add_histogram_peaks_and_boundaries: bool = True, - add_drift_events: bool = True, - weight_histogram_by_amplitude: bool = False, - localised_spikes_only: bool = False, - exclude_noise: bool = False, - gain: float | None = None, - large_amplitude_only_segment_size: float = 800.0, - localised_spikes_channel_cutoff: int = 20, - ): - if not isinstance(sorter_output, Path): - sorter_output = Path(sorter_output) - - if not sorter_output.is_dir(): - raise ValueError(f"No output folder found at {sorter_output}") - - if not (sorter_output / "params.py").is_file(): - raise ValueError( - "The `sorting_output` path is not a valid kilosort output" - "folder. It does not contain a `params.py` file`." - ) - - plot_data = dict( - sorter_output=sorter_output, - only_include_large_amplitude_spikes=only_include_large_amplitude_spikes, - decimate=decimate, - add_histogram_plot=add_histogram_plot, - add_histogram_peaks_and_boundaries=add_histogram_peaks_and_boundaries, - add_drift_events=add_drift_events, - weight_histogram_by_amplitude=weight_histogram_by_amplitude, - localised_spikes_only=localised_spikes_only, - exclude_noise=exclude_noise, - gain=gain, - large_amplitude_only_segment_size=large_amplitude_only_segment_size, - localised_spikes_channel_cutoff=localised_spikes_channel_cutoff, - ) - BaseWidget.__init__(self, plot_data, backend="matplotlib") - - def plot_matplotlib(self, data_plot: dict, **unused_kwargs) -> None: - - dp = to_attr(data_plot) - - spike_times, spike_amplitudes, spike_depths, _ = self._compute_spike_amplitude_and_depth( - dp.sorter_output, dp.localised_spikes_only, dp.exclude_noise, dp.gain, dp.localised_spikes_channel_cutoff - ) - - # Calculate the amplitude range for plotting first, so the scale is always the - # same across all options (e.g. 
decimation) which helps with interpretability. - if dp.only_include_large_amplitude_spikes: - amplitude_range_all_spikes = ( - spike_amplitudes.min(), - spike_amplitudes.max(), - ) - else: - amplitude_range_all_spikes = np.percentile(spike_amplitudes, (1, 90)) - - if dp.decimate: - spike_times = spike_times[:: dp.decimate] - spike_amplitudes = spike_amplitudes[:: dp.decimate] - spike_depths = spike_depths[:: dp.decimate] - - if dp.only_include_large_amplitude_spikes: - spike_times, spike_amplitudes, spike_depths = self._filter_large_amplitude_spikes( - spike_times, spike_amplitudes, spike_depths, dp.large_amplitude_only_segment_size - ) - - # Setup axis and plot the raster drift map - fig = plt.figure(figsize=(10, 10 * (6 / 8))) - - if dp.add_histogram_plot: - gs = fig.add_gridspec(1, 2, width_ratios=[1, 5]) - hist_axis = fig.add_subplot(gs[0]) - raster_axis = fig.add_subplot(gs[1], sharey=hist_axis) - else: - raster_axis = fig.add_subplot() - - self._plot_kilosort_drift_map_raster( - spike_times, - spike_amplitudes, - spike_depths, - amplitude_range_all_spikes, - axis=raster_axis, - ) - - if not dp.add_histogram_plot: - raster_axis.set_xlabel("time") - raster_axis.set_ylabel("y position") - self.axes = [raster_axis] - return - - # If the histogram plot is requested, plot it alongside - # it's peak colouring, bounds display and drift point display. 
- hist_axis.set_xlabel("count") - raster_axis.set_xlabel("time") - hist_axis.set_ylabel("y position") - - bin_centers, counts = self._compute_activity_histogram( - spike_amplitudes, spike_depths, dp.weight_histogram_by_amplitude - ) - hist_axis.plot(counts, bin_centers, color="black", linewidth=1) - - if dp.add_histogram_peaks_and_boundaries: - drift_events = self._color_histogram_peaks_and_detect_drift_events( - spike_times, spike_depths, counts, bin_centers, hist_axis - ) - - if dp.add_drift_events and np.any(drift_events): - raster_axis.scatter(drift_events[:, 0], drift_events[:, 1], facecolors="r", edgecolors="none") - for i, _ in enumerate(drift_events): - raster_axis.text( - drift_events[i, 0] + 1, drift_events[i, 1], str(np.round(drift_events[i, 2])), color="r" - ) - self.axes = [hist_axis, raster_axis] - - def _plot_kilosort_drift_map_raster( - self, - spike_times: np.ndarray, - spike_amplitudes: np.ndarray, - spike_depths: np.ndarray, - amplitude_range: np.ndarray | tuple, - axis: matplotlib.axes.Axes, - ) -> None: - """ - Plot a drift raster plot in the kilosort style. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - spike_times : np.ndarray - (num_spikes,) array of spike times. - spike_amplitudes : np.ndarray - (num_spikes,) array of corresponding spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of corresponding spike depths. - amplitude_range : np.ndarray | tuple - (2,) array of min, max amplitude values for color binning. - axis : matplotlib.axes.Axes - Matplotlib axes object on which to plot the drift map. 
- """ - n_color_bins = 20 - marker_size = 0.5 - - color_bins = np.linspace(amplitude_range[0], amplitude_range[1], n_color_bins) - - colors = plt.get_cmap("gray")(np.linspace(0, 1, n_color_bins))[::-1] - - for bin_idx in range(n_color_bins - 1): - - spikes_in_amplitude_bin = np.logical_and( - spike_amplitudes >= color_bins[bin_idx], spike_amplitudes <= color_bins[bin_idx + 1] - ) - axis.scatter( - spike_times[spikes_in_amplitude_bin], - spike_depths[spikes_in_amplitude_bin], - color=colors[bin_idx], - s=marker_size, - antialiased=True, - ) - - def _compute_activity_histogram( - self, spike_amplitudes: np.ndarray, spike_depths: np.ndarray, weight_histogram_by_amplitude: bool - ) -> tuple[np.ndarray, ...]: - """ - Compute the activity histogram for the kilosort drift map's left-side plot. - - Parameters - ---------- - spike_amplitudes : np.ndarray - (num_spikes,) array of spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of spike depths. - weight_histogram_by_amplitude : bool - If `True`, the spike amplitudes are taken into consideration when generating the - histogram. The amplitudes are scaled to the range [0, 1] then summed for each bin, - to generate the histogram values. If `False`, counts (i.e. num spikes per bin) - are used. - - Returns - ------- - bin_centers : np.ndarray - The spatial bin centers (probe depth) for the histogram. - values : np.ndarray - The histogram values. If `weight_histogram_by_amplitude` is `False`, these - values represent are counts, otherwise they are counts weighted by amplitude. - """ - assert ( - spike_amplitudes.dtype == np.float64 - ), "`spike amplitudes should be high precision as many values are summed." 
- - bin_um = 2 - bins = np.arange(spike_depths.min() - bin_um, spike_depths.max() + bin_um, bin_um) - values, bins = np.histogram(spike_depths, bins=bins) - bin_centers = (bins[:-1] + bins[1:]) / 2 - - if weight_histogram_by_amplitude: - bin_indices = np.digitize(spike_depths, bins, right=True) - 1 - values = np.zeros(bin_indices.max() + 1, dtype=np.float64) - scaled_spike_amplitudes = (spike_amplitudes - spike_amplitudes.min()) / np.ptp(spike_amplitudes) - np.add.at(values, bin_indices, scaled_spike_amplitudes) - - return bin_centers, values - - def _color_histogram_peaks_and_detect_drift_events( - self, - spike_times: np.ndarray, - spike_depths: np.ndarray, - counts: np.ndarray, - bin_centers: np.ndarray, - hist_axis: matplotlib.axes.Axes, - ) -> np.ndarray: - """ - Given an activity histogram, color the peaks red (isolated peak) or - blue (peak overlaps with other peaks) and compute spatial drift - events for isolated peaks across time bins. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - spike_times : np.ndarray - (num_spikes,) array of spike times. - spike_depths : np.ndarray - (num_spikes,) array of corresponding spike depths. - counts : np.ndarray - (num_bins,) array of histogram bin counts. - bin_centers : np.ndarray - (num_bins,) array of histogram bin centers. - hist_axis : matplotlib.axes.Axes - Axes on which the histogram is plot, to add peaks. - - Returns - ------- - drift_events : np.ndarray - A (num_drift_events, 3) array of drift events. The columns are - (time_position, spatial_position, drift_value). The drift - value is computed per time, spatial bin as the difference between - the median position of spikes in the bin, and the bin center. 
- """ - all_peak_indexes = scipy.signal.find_peaks( - counts, - )[0] - - # Filter low-frequency peaks, so they are not included in the - # step to determine whether peaks are overlapping (new step - # introduced in the port to python) - bin_above_freq_threshold = counts[all_peak_indexes] > 0.3 * spike_times[-1] - filtered_peak_indexes = all_peak_indexes[bin_above_freq_threshold] - - drift_events = [] - for idx, peak_index in enumerate(filtered_peak_indexes): - - peak_count = counts[peak_index] - - # Find the start and end of peak min/max bounds (5% of amplitude) - start_position = np.where(counts[:peak_index] < peak_count * 0.05)[0].max() - end_position = np.where(counts[peak_index:] < peak_count * 0.05)[0].min() + peak_index - - if ( # bounds include another, different histogram peak - idx > 0 - and start_position < filtered_peak_indexes[idx - 1] - or idx < filtered_peak_indexes.size - 1 - and end_position > filtered_peak_indexes[idx + 1] - ): - hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="blue") - continue - - else: - for position in [start_position, end_position]: - hist_axis.axhline(bin_centers[position], 0, counts.max(), color="grey", linestyle="--") - hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="red") - - # For isolated histogram peaks, detect the drift events, defined as - # difference between spatial bin center and median spike depth in the bin - # over 6 um (in time / spatial bins with at least 10 spikes). 
- depth_in_window = np.logical_and( - spike_depths > bin_centers[start_position], - spike_depths < bin_centers[end_position], - ) - current_spike_depths = spike_depths[depth_in_window] - current_spike_times = spike_times[depth_in_window] - - window_s = 10 - - all_time_bins = np.arange(0, np.ceil(spike_times[-1]).astype(int), window_s) - for time_bin in all_time_bins: - - spike_in_time_bin = np.logical_and( - current_spike_times >= time_bin, current_spike_times <= time_bin + window_s - ) - drift_size = bin_centers[peak_index] - np.median(current_spike_depths[spike_in_time_bin]) - - # 6 um is the hardcoded threshold for drift, and we want at least 10 spikes for the median calculation - bin_has_drift = np.abs(drift_size) > 6 and np.sum(spike_in_time_bin, dtype=np.int16) > 10 - if bin_has_drift: - drift_events.append((time_bin + window_s / 2, bin_centers[peak_index], drift_size)) - - drift_events = np.array(drift_events) - - return drift_events - - def _compute_spike_amplitude_and_depth( - self, - sorter_output: str | Path, - localised_spikes_only, - exclude_noise, - gain: float | None, - localised_spikes_channel_cutoff: int, - ) -> tuple[np.ndarray, ...]: - """ - Compute the amplitude and depth of all detected spikes from the kilosort output. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - sorter_output : str | Path - Path to the kilosort run sorting output. - localised_spikes_only : bool - If `True`, only spikes with small spatial footprint (i.e. 20 channels within 1/2 of the - amplitude of the maximum loading channel) and which are close to the average depth for - the cluster are returned. - gain: float | None - If a float provided, the `spike_amplitudes` will be scaled by this gain. 
- localised_spikes_channel_cutoff : int - If `localised_spikes_only` is `True`, spikes that have less than half of the - maximum loading channel over a range of n channels are removed. - This sets the number of channels. - - Returns - ------- - spike_times : np.ndarray - (num_spikes,) array of spike times. - spike_amplitudes : np.ndarray - (num_spikes,) array of corresponding spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of corresponding depths (probe y-axis location). - - Notes - ----- - In `_template_positions_amplitudes` spike depths is calculated as simply the template - depth, for each spike (so it is the same for all spikes in a cluster). Here we need - to find the depth of each individual spike, using its low-dimensional projection. - `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. - Taking the first component, the subset of 32 channels associated with this - spike are indexed to get the actual channel locations (in um). Then, the channel - locations are weighted by their PC values. 
- """ - if isinstance(sorter_output, str): - sorter_output = Path(sorter_output) - - params = self._load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise) - - if localised_spikes_only: - localised_templates = [] - - for idx, template in enumerate(params["templates"]): - max_channel = np.max(np.abs(params["templates"][idx, :, :])) - channels_over_threshold = np.max(np.abs(params["templates"][idx, :, :]), axis=0) > 0.5 * max_channel - channel_ids_over_threshold = np.where(channels_over_threshold)[0] - - if np.ptp(channel_ids_over_threshold) <= localised_spikes_channel_cutoff: - localised_templates.append(idx) - - localised_template_by_spike = np.isin(params["spike_templates"], localised_templates) - - params["spike_templates"] = params["spike_templates"][localised_template_by_spike] - params["spike_times"] = params["spike_times"][localised_template_by_spike] - params["spike_clusters"] = params["spike_clusters"][localised_template_by_spike] - params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike] - params["pc_features"] = params["pc_features"][localised_template_by_spike] - - # Compute spike depths - pc_features = params["pc_features"][:, 0, :] - pc_features[pc_features < 0] = 0 - - # Get the channel indexes corresponding to the 32 channels from the PC. - spike_features_indices = params["pc_features_indices"][params["spike_templates"], :] - - ycoords = params["channel_positions"][:, 1] - spike_feature_ycoords = ycoords[spike_features_indices] - - spike_depths = np.sum(spike_feature_ycoords * pc_features**2, axis=1) / np.sum(pc_features**2, axis=1) - - # Compute amplitudes, scale if required and drop un-localised spikes before returning. 
- spike_amplitudes, _, _, _, unwhite_templates, *_ = self._template_positions_amplitudes( - params["templates"], - params["whitening_matrix_inv"], - ycoords, - params["spike_templates"], - params["temp_scaling_amplitudes"], - ) - - if gain is not None: - spike_amplitudes *= gain - - max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1) - spike_sites = max_site[params["spike_templates"]] - - if localised_spikes_only: - # Interpolate the channel ids to location. - # Remove spikes > 5 um from average position - # Above we already removed non-localized templates, but that on its own is insufficient. - # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient - b = stats.linregress(spike_depths, spike_sites).slope - i = np.abs(spike_sites - b * spike_depths) <= 5 - - params["spike_times"] = params["spike_times"][i] - spike_amplitudes = spike_amplitudes[i] - spike_depths = spike_depths[i] - - return params["spike_times"], spike_amplitudes, spike_depths, spike_sites - - def _filter_large_amplitude_spikes( - self, - spike_times: np.ndarray, - spike_amplitudes: np.ndarray, - spike_depths: np.ndarray, - large_amplitude_only_segment_size, - ) -> tuple[np.ndarray, ...]: - """ - Return spike properties with only the largest-amplitude spikes included. The probe - is split into egments, and within each segment the mean and std computed. - Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded - Splitting the probe is only done for the exclusion step, the returned array are flat. - - Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns - copies of these arrays containing only the large amplitude spikes. 
- """ - spike_bool = np.zeros_like(spike_amplitudes, dtype=bool) - - segment_size_um = large_amplitude_only_segment_size - probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um - - for segment_left_edge in probe_segments_left_edges: - segment_right_edge = segment_left_edge + segment_size_um - - spikes_in_seg = np.where( - np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge) - )[0] - spike_amps_in_seg = spike_amplitudes[spikes_in_seg] - is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1) - - spike_bool[spikes_in_seg] = is_high_amplitude - - spike_times = spike_times[spike_bool] - spike_amplitudes = spike_amplitudes[spike_bool] - spike_depths = spike_depths[spike_bool] - - return spike_times, spike_amplitudes, spike_depths - - - diff --git a/driftmap_viewer/driftmapviewer_new.py b/driftmap_viewer/driftmapviewer_new.py deleted file mode 100644 index 1e270c9..0000000 --- a/driftmap_viewer/driftmapviewer_new.py +++ /dev/null @@ -1,334 +0,0 @@ - -from pathlib import Path -import matplotlib.axis -import scipy.signal -from spikeinterface.core import read_python -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt - -import matplotlib.pyplot as plt -from scipy import stats - - -def get_spikes_info(sorter_output, ks_version, exclude_noise): - """""" - if ks_version == "kilosort4": - spike_times, spike_amplitudes, spike_depths = kilosort_4.get_spikes_info_ks4( - sorter_output, exclude_noise - ) - else: - spike_times, spike_amplitudes, spike_depths = kilosort1_3.get_spike_info( - sorter_output, exclude_noise - ) - - return spike_times, spike_amplitudes, spike_depths - - - -# TODO: GAIN DOES NOTHING -# separate function to add histogram - -def get_drift_map_plot( - sorter_output: str | Path, - only_include_large_amplitude_spikes: bool = True, - decimate: None | int = None, - add_histogram_plot: bool = False, - 
weight_histogram_by_amplitude: bool = False, - exclude_noise: bool = False, - large_amplitude_only_segment_size: float = 800.0, -): - """ - """ - # ks_version = get_ks_version() - sorter_output = Path(sorter_output) - - ks_version = "kilosort4" - - spike_times, spike_amplitudes, spike_depths = get_spikes_info( - sorter_output, ks_version, exclude_noise - ) - - # TODO: this is super weird, can be imrpoved? - if log_transform_amplitudes: - spike_amplitudes = np.log10(spike_amplitudes) - - # Calculate the amplitude range for plotting first, so the scale is always the - # same across all options (e.g. decimation) which helps with interpretability. - amplitude_range_all_spikes = ( - spike_amplitudes.min(), - spike_amplitudes.max(), - ) - - if decimate: - spike_times = spike_times[:: decimate] - spike_amplitudes = spike_amplitudes[:: decimate] - spike_depths = spike_depths[:: decimate] - - if only_include_large_amplitude_spikes: - spike_times, spike_amplitudes, spike_depths = _filter_large_amplitude_spikes( - spike_times, spike_amplitudes, spike_depths, - large_amplitude_only_segment_size - ) - - # Setup axis and plot the raster drift map - fig = plt.figure(figsize=(10, 10 * (6 / 8))) - - if add_histogram_plot: - gs = fig.add_gridspec(1, 2, width_ratios=[1, 5]) - hist_axis = fig.add_subplot(gs[0]) - raster_axis = fig.add_subplot(gs[1], sharey=hist_axis) - else: - raster_axis = fig.add_subplot() - - print(amplitude_range_all_spikes) - _plot_kilosort_drift_map_raster( - spike_times, - spike_amplitudes, - spike_depths, - amplitude_range_all_spikes, - axis=raster_axis, - ) - - if not add_histogram_plot: - raster_axis.set_xlabel("time") - raster_axis.set_ylabel("y position") - return fig - - # If the histogram plot is requested, plot it alongside - # it's peak colouring, bounds display and drift point display. 
- hist_axis.set_xlabel("count") - raster_axis.set_xlabel("time") - hist_axis.set_ylabel("y position") - - bin_centers, counts = _compute_activity_histogram( - spike_amplitudes, spike_depths, weight_histogram_by_amplitude - ) - hist_axis.plot(counts, bin_centers, color="black", linewidth=1) - - return fig - -# TODO: not sure about this, much point in computing as segments? Just combine? Its just another parameter to track... -# hmm I guess it makes sense if going through brain regions, and if want to do the entire thing, you can set segment to probe... -def _filter_large_amplitude_spikes( - spike_times: np.ndarray, - spike_amplitudes: np.ndarray, - spike_depths: np.ndarray, - large_amplitude_only_segment_size, -) -> tuple[np.ndarray, ...]: - """ - Return spike properties with only the largest-amplitude spikes included. The probe - is split into segments, and within each segment the mean and std computed. - Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded - Splitting the probe is only done for the exclusion step, the returned array are flat. - - Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns - copies of these arrays containing only the large amplitude spikes. 
- """ - spike_bool = np.zeros_like(spike_amplitudes, dtype=bool) - - segment_size_um = large_amplitude_only_segment_size - probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um - - for segment_left_edge in probe_segments_left_edges: - segment_right_edge = segment_left_edge + segment_size_um - - spikes_in_seg = np.where( - np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge) - )[0] - spike_amps_in_seg = spike_amplitudes[spikes_in_seg] - is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1) - - spike_bool[spikes_in_seg] = is_high_amplitude - - spike_times = spike_times[spike_bool] - spike_amplitudes = spike_amplitudes[spike_bool] - spike_depths = spike_depths[spike_bool] - - return spike_times, spike_amplitudes, spike_depths - -def _plot_kilosort_drift_map_raster( - spike_times: np.ndarray, - spike_amplitudes: np.ndarray, - spike_depths: np.ndarray, - amplitude_range: np.ndarray | tuple, - axis: matplotlib.axes.Axes, -) -> None: - """ - Plot a drift raster plot in the kilosort style. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - spike_times : np.ndarray - (num_spikes,) array of spike times. - spike_amplitudes : np.ndarray - (num_spikes,) array of corresponding spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of corresponding spike depths. - amplitude_range : np.ndarray | tuple - (2,) array of min, max amplitude values for color binning. - axis : matplotlib.axes.Axes - Matplotlib axes object on which to plot the drift map. 
- """ - n_color_bins = 20 - marker_size = 0.5 - - color_bins = np.linspace(amplitude_range[0], amplitude_range[1], n_color_bins) - - colors = plt.get_cmap("gray")(np.linspace(0, 1, n_color_bins))[::-1] - - for bin_idx in range(n_color_bins - 1): - - spikes_in_amplitude_bin = np.logical_and( - spike_amplitudes >= color_bins[bin_idx], spike_amplitudes <= color_bins[bin_idx + 1] - ) - axis.scatter( - spike_times[spikes_in_amplitude_bin], - spike_depths[spikes_in_amplitude_bin], - color=colors[bin_idx], - s=marker_size, - antialiased=True, - ) - -def _compute_activity_histogram( - spike_amplitudes: np.ndarray, spike_depths: np.ndarray, weight_histogram_by_amplitude: bool -) -> tuple[np.ndarray, ...]: - """ - Compute the activity histogram for the kilosort drift map's left-side plot. - - Parameters - ---------- - spike_amplitudes : np.ndarray - (num_spikes,) array of spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of spike depths. - weight_histogram_by_amplitude : bool - If `True`, the spike amplitudes are taken into consideration when generating the - histogram. The amplitudes are scaled to the range [0, 1] then summed for each bin, - to generate the histogram values. If `False`, counts (i.e. num spikes per bin) - are used. - - Returns - ------- - bin_centers : np.ndarray - The spatial bin centers (probe depth) for the histogram. - values : np.ndarray - The histogram values. If `weight_histogram_by_amplitude` is `False`, these - values represent are counts, otherwise they are counts weighted by amplitude. 
- """ - spike_amplitudes = spike_amplitudes.astype(np.float64) - - bin_um = 2 - bins = np.arange(spike_depths.min() - bin_um, spike_depths.max() + bin_um, bin_um) - values, bins = np.histogram(spike_depths, bins=bins) - bin_centers = (bins[:-1] + bins[1:]) / 2 - - if weight_histogram_by_amplitude: - bin_indices = np.digitize(spike_depths, bins, right=True) - 1 - values = np.zeros(bin_indices.max() + 1, dtype=np.float64) - scaled_spike_amplitudes = (spike_amplitudes - spike_amplitudes.min()) / np.ptp(spike_amplitudes) - np.add.at(values, bin_indices, scaled_spike_amplitudes) - - return bin_centers, values - - -def _color_histogram_peaks_and_detect_drift_events( - spike_times: np.ndarray, - spike_depths: np.ndarray, - counts: np.ndarray, - bin_centers: np.ndarray, - hist_axis: matplotlib.axes.Axes, -) -> np.ndarray: - """ - Given an activity histogram, color the peaks red (isolated peak) or - blue (peak overlaps with other peaks) and compute spatial drift - events for isolated peaks across time bins. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - spike_times : np.ndarray - (num_spikes,) array of spike times. - spike_depths : np.ndarray - (num_spikes,) array of corresponding spike depths. - counts : np.ndarray - (num_bins,) array of histogram bin counts. - bin_centers : np.ndarray - (num_bins,) array of histogram bin centers. - hist_axis : matplotlib.axes.Axes - Axes on which the histogram is plot, to add peaks. - - Returns - ------- - drift_events : np.ndarray - A (num_drift_events, 3) array of drift events. The columns are - (time_position, spatial_position, drift_value). The drift - value is computed per time, spatial bin as the difference between - the median position of spikes in the bin, and the bin center. 
- """ - all_peak_indexes = scipy.signal.find_peaks( - counts, - )[0] - - # Filter low-frequency peaks, so they are not included in the - # step to determine whether peaks are overlapping (new step - # introduced in the port to python) - bin_above_freq_threshold = counts[all_peak_indexes] > 0.3 * spike_times[-1] - filtered_peak_indexes = all_peak_indexes[bin_above_freq_threshold] - - drift_events = [] - for idx, peak_index in enumerate(filtered_peak_indexes): - - peak_count = counts[peak_index] - - # Find the start and end of peak min/max bounds (5% of amplitude) - start_position = np.where(counts[:peak_index] < peak_count * 0.05)[0].max() - end_position = np.where(counts[peak_index:] < peak_count * 0.05)[0].min() + peak_index - - if ( # bounds include another, different histogram peak - idx > 0 - and start_position < filtered_peak_indexes[idx - 1] - or idx < filtered_peak_indexes.size - 1 - and end_position > filtered_peak_indexes[idx + 1] - ): - hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="blue") - continue - - else: - for position in [start_position, end_position]: - hist_axis.axhline(bin_centers[position], 0, counts.max(), color="grey", linestyle="--") - hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="red") - - # For isolated histogram peaks, detect the drift events, defined as - # difference between spatial bin center and median spike depth in the bin - # over 6 um (in time / spatial bins with at least 10 spikes). 
- depth_in_window = np.logical_and( - spike_depths > bin_centers[start_position], - spike_depths < bin_centers[end_position], - ) - current_spike_depths = spike_depths[depth_in_window] - current_spike_times = spike_times[depth_in_window] - - window_s = 10 - - all_time_bins = np.arange(0, np.ceil(spike_times[-1]).astype(int), window_s) - for time_bin in all_time_bins: - - spike_in_time_bin = np.logical_and( - current_spike_times >= time_bin, current_spike_times <= time_bin + window_s - ) - drift_size = bin_centers[peak_index] - np.median(current_spike_depths[spike_in_time_bin]) - - # 6 um is the hardcoded threshold for drift, and we want at least 10 spikes for the median calculation - bin_has_drift = np.abs(drift_size) > 6 and np.sum(spike_in_time_bin, dtype=np.int16) > 10 - if bin_has_drift: - drift_events.append((time_bin + window_s / 2, bin_centers[peak_index], drift_size)) - - drift_events = np.array(drift_events) - - return drift_events \ No newline at end of file diff --git a/__init__.py b/driftmap_viewer/extractors/__init__.py similarity index 100% rename from __init__.py rename to driftmap_viewer/extractors/__init__.py diff --git a/driftmap_viewer/extractors/analyzer_helpers.py b/driftmap_viewer/extractors/analyzer_helpers.py new file mode 100644 index 0000000..0760bef --- /dev/null +++ b/driftmap_viewer/extractors/analyzer_helpers.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import warnings + +import numpy as np +import spikeinterface as si + + +def get_sorting_analyzer( + analyzer: si.SortingAnalyzer, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Get the required data from the sorting analyzer. Note that this + will not get all detected spikes, but rather the number of spikes + specified when creating the analyzer, `max_spikes_per_unit`. 
+ """ + random_spike_indices = analyzer.get_extension("random_spikes").data[ + "random_spikes_indices" + ] + spike_vector = analyzer.sorting.to_spike_vector() + spike_times = ( + spike_vector["sample_index"][random_spike_indices] + / analyzer.sorting.get_sampling_frequency() + ) + spike_amplitudes = np.abs( + analyzer.get_extension("spike_amplitudes").data["amplitudes"] + ) + + spike_depths = analyzer.get_extension("spike_locations").data["spike_locations"][ + "y" + ] + spike_clusters = spike_vector["unit_index"][random_spike_indices] + + # Get the templates, assume only one method was used. If multiple + # methods were used, use the first and throw a warning. If people + # want this exposed, it can be exposed, but at the moment seems too particular. + templates_dict = analyzer.get_extension("templates").data + all_template_keys = templates_dict.keys() + template_key = list(all_template_keys)[0] + + if len(all_template_keys) != 1: + warnings.warn( + f"Multiple template calculation methods detected. Using {template_key}" + ) + + templates = analyzer.get_extension("templates").data[template_key] + channel_locations = analyzer.get_channel_locations() + + return ( + spike_times, + spike_amplitudes, + spike_depths, + spike_clusters, + templates, + channel_locations, + ) + + +def get_noise_mask( + exclude_noise: bool | str, + spike_clusters: np.ndarray, + analyzer: si.SortingAnalyzer, +) -> np.ndarray: + """ """ + if exclude_noise is True: + raise ValueError( + f"When using SortingAnalyzer, `exclude_noise` must be a string of the " + f"name of the labels to use, passed to `analyzer.get_sorting_property()." 
+ f"Properties on this analyzer are: {analyzer.sorting.get_property_keys()}" + ) + + assert isinstance(exclude_noise, str), "`exclude_noise` must be a string" + labels = analyzer.get_sorting_property(exclude_noise) + noise_mask = (labels == "noise")[spike_clusters] # TODO: make sure to test this + + return noise_mask diff --git a/driftmap_viewer/kilosort1_3.py b/driftmap_viewer/extractors/kilosort1_3.py similarity index 65% rename from driftmap_viewer/kilosort1_3.py rename to driftmap_viewer/extractors/kilosort1_3.py index a4b8da9..df6fca6 100644 --- a/driftmap_viewer/kilosort1_3.py +++ b/driftmap_viewer/extractors/kilosort1_3.py @@ -1,11 +1,16 @@ from __future__ import annotations + from typing import TYPE_CHECKING + if TYPE_CHECKING: from pathlib import Path + import numpy as np import numpy as np from spikeinterface.core import read_python +from driftmap_viewer.extractors import kilosort_helpers + def get_spikes_info_ks1_3( sorter_output: str | Path, @@ -32,8 +37,10 @@ def get_spikes_info_ks1_3( Notes ----- - In `_template_positions_amplitudes` spike depths is calculated as simply the template - depth, for each spike (so it is the same for all spikes in a cluster). Here we need + In `_template_positions_amplitudes` spike depths is + calculated as simply the template depth, for each spike + (so it is the same for all spikes in a cluster). Here + we need to find the depth of each individual spike, using its low-dimensional projection. `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. Taking the first component, the subset of 32 channels associated with this @@ -48,33 +55,40 @@ def get_spikes_info_ks1_3( # Compute spike depths pc_features = params["pc_features"][:, 0, :] pc_features[pc_features < 0] = 0 + pc_features = pc_features**2 # Get the channel indexes corresponding to the 32 channels from the PC. 
- spike_features_indices = params["pc_features_indices"][params["spike_templates"], :] + spike_features_indices = params["pc_features_indices"][params["spike_clusters"], :] ycoords = params["channel_positions"][:, 1] spike_feature_ycoords = ycoords[spike_features_indices] - spike_depths = np.sum(spike_feature_ycoords * pc_features ** 2, axis=1) / np.sum( - pc_features ** 2, axis=1) + # TODO: document this, it's from Nick Steinmetz, or phy + spike_depths = np.sum(spike_feature_ycoords * pc_features, axis=1) / np.sum( + pc_features, axis=1 + ) - # Compute amplitudes, scale if required and drop un-localised spikes before returning. - spike_amplitudes, _, _, _, unwhite_templates, *_ = _template_positions_amplitudes( + spike_amplitudes, white_templates = _template_positions_amplitudes( params["templates"], params["whitening_matrix_inv"], - ycoords, - params["spike_templates"], + params["spike_clusters"], params["temp_scaling_amplitudes"], ) - return params["spike_times"], spike_amplitudes, spike_depths + return ( + params["spike_times"], + spike_amplitudes, + spike_depths, + params["spike_clusters"].squeeze(), + white_templates, + params["channel_positions"], + ) def _template_positions_amplitudes( templates: np.ndarray, inverse_whitening_matrix: np.ndarray, - ycoords: np.ndarray, - spike_templates: np.ndarray, + spike_clusters: np.ndarray, template_scaling_amplitudes: np.ndarray, ) -> tuple[np.ndarray, ...]: """ @@ -92,7 +106,7 @@ def _template_positions_amplitudes( unwhiten templates. ycoords : np.ndarray (num_channels,) array of the y-axis (depth) channel positions. - spike_templates : np.ndarray + spike_clusters : np.ndarray (num_spikes,) array indicating the template associated with each spike. template_scaling_amplitudes : np.ndarray (num_spikes,) array holding the scaling amplitudes, by which the @@ -106,16 +120,8 @@ def _template_positions_amplitudes( (num_spikes,) array of the depth (probe y-axis) of each spike. 
Note this is just the template depth for each spike (i.e. depth of all spikes from the same cluster are identical). - template_amplitudes : np.ndarray - (num_templates,) Amplitude of each template, calculated as average of spike amplitudes. - template_depths : np.ndarray - (num_templates,) array of the depth of each template. - unwhite_templates : np.ndarray - Unwhitened templates (num_clusters, num_samples, num_channels). - trough_peak_durations : np.ndarray - (num_templates, ) array of durations from trough to peak for each template waveform - waveforms : np.ndarray - (num_templates, num_samples) Waveform of each template, taken as the signal on the maximum loading channel. + white_templates : np.ndarray + Whitened templates (num_clusters, num_samples, num_channels). """ # Unwhiten the template waveforms unwhite_templates = np.zeros_like(templates) @@ -127,62 +133,30 @@ def _template_positions_amplitudes( # Take the max amplitude for each channel, then use the channel # with most signal as template amplitude. Zero any small channel amplitudes. - template_amplitudes_per_channel = np.max(unwhite_templates, axis=1) - np.min(unwhite_templates, axis=1) + template_amplitudes_per_channel = np.max(unwhite_templates, axis=1) - np.min( + unwhite_templates, axis=1 + ) template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1) threshold_values = 0.3 * template_amplitudes_unscaled - template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0 - - # Calculate the template depth as the center of mass based on channel amplitudes - template_depths = np.sum(template_amplitudes_per_channel * ycoords[np.newaxis, :], axis=1) / np.sum( - template_amplitudes_per_channel, axis=1 - ) + template_amplitudes_per_channel[ + template_amplitudes_per_channel < threshold_values[:, np.newaxis] + ] = 0 # Next, find the depth of each spike based on its template. 
Recompute the template # amplitudes as the average of the spike amplitudes ('since # tempScalingAmps are equal mean for all templates') - spike_amplitudes = template_amplitudes_unscaled[spike_templates] * template_scaling_amplitudes - - # Take the average of all spike amplitudes to get actual template amplitudes - # (since tempScalingAmps are equal mean for all templates) - num_indices = templates.shape[0] - sum_per_index = np.zeros(num_indices, dtype=np.float64) - np.add.at(sum_per_index, spike_templates, spike_amplitudes) - counts = np.bincount(spike_templates, minlength=num_indices) - template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0) - - # Each spike's depth is the depth of its template - spike_depths = template_depths[spike_templates] - - # Get channel with the largest amplitude (take that as the waveform) - max_site = np.argmax(np.max(np.abs(templates), axis=1), axis=1) - - # Use template channel with max signal as waveform - waveforms = np.empty(templates.shape[:2]) - for idx, template in enumerate(templates): - waveforms[idx, :] = templates[idx, :, max_site[idx]] - - # Get trough-to-peak time for each template. Find the trough as the - # minimum signal for the template waveform. The duration (in - # samples) is the num samples from trough to the largest value - # following the trough. - waveform_trough = np.argmin(waveforms, axis=1) - - trough_peak_durations = np.zeros(waveforms.shape[0]) - for idx, tmp_max in enumerate(waveforms): - trough_peak_durations[idx] = np.argmax(tmp_max[waveform_trough[idx] :]) + spike_amplitudes = ( + template_amplitudes_unscaled[spike_clusters] * template_scaling_amplitudes + ) return ( spike_amplitudes, - spike_depths, - template_depths, - template_amplitudes, - unwhite_templates, - trough_peak_durations, - waveforms, + templates, ) + def _load_ks_dir(sorter_output: Path, load_pcs: bool = False) -> dict: """ Loads the output of Kilosort into a `params` dict. 
@@ -218,7 +192,7 @@ def _load_ks_dir(sorter_output: Path, load_pcs: bool = False) -> dict: params = read_python(sorter_output / "params.py") spike_times = np.load(sorter_output / "spike_times.npy") / params["sample_rate"] - spike_templates = np.load(sorter_output / "spike_templates.npy") + spike_clusters = kilosort_helpers.load_spike_clusters(sorter_output) temp_scaling_amplitudes = np.load(sorter_output / "amplitudes.npy") @@ -230,7 +204,7 @@ def _load_ks_dir(sorter_output: Path, load_pcs: bool = False) -> dict: new_params = { "spike_times": spike_times.squeeze(), - "spike_templates": spike_templates.squeeze(), + "spike_clusters": spike_clusters.squeeze(), "pc_features": pc_features, "pc_features_indices": pc_features_indices, "temp_scaling_amplitudes": temp_scaling_amplitudes.squeeze(), diff --git a/driftmap_viewer/extractors/kilosort4.py b/driftmap_viewer/extractors/kilosort4.py new file mode 100644 index 0000000..2fabf55 --- /dev/null +++ b/driftmap_viewer/extractors/kilosort4.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +import numpy as np + +from driftmap_viewer.extractors import kilosort_helpers + + +def compute_spike_amplitudes( + templates: np.ndarray, + spike_clusters: np.ndarray, + amplitudes: np.ndarray, +) -> np.ndarray: + # This is based on https://github.com/MouseLand/Kilosort/issues/804, + # need to double check it + template_ptp = np.max(templates, axis=1) - np.min(templates, axis=1) + template_max_peaks = np.max(template_ptp, axis=1) + spike_amplitudes = template_max_peaks[spike_clusters] * amplitudes + return spike_amplitudes + + +def get_spikes_info_ks4( + sorter_output: Path, +) -> tuple[np.ndarray, ...]: + + # TODO: kept_spikes is not always the same size as other spike data + # Currently not used for loading + # kept_spikes = np.load(sorter_output / "kept_spikes.npy") + + spike_times = np.load(sorter_output / "spike_times.npy") + amplitudes = 
np.load(sorter_output / "amplitudes.npy") + spike_depths = np.load(sorter_output / "spike_positions.npy")[:, 1] + spike_clusters = kilosort_helpers.load_spike_clusters(sorter_output) + + templates = np.load(sorter_output / "templates.npy") # rename unwhiten? + channel_positions = np.load(sorter_output / "channel_positions.npy") + spike_amplitudes = compute_spike_amplitudes(templates, spike_clusters, amplitudes) + + return ( + spike_times, + spike_amplitudes, + spike_depths, + spike_clusters, + templates, + channel_positions, + ) diff --git a/driftmap_viewer/extractors/kilosort_helpers.py b/driftmap_viewer/extractors/kilosort_helpers.py new file mode 100644 index 0000000..5338402 --- /dev/null +++ b/driftmap_viewer/extractors/kilosort_helpers.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pandas as pd + + +def load_cluster_groups(cluster_path: Path) -> tuple[np.ndarray, np.ndarray]: + """Load kilosort ``cluster_groups`` file. + + Contains a table of quality assignments, one per unit. These can be + "noise", "mua", "good" or "unsorted". + + There are slight formatting differences between the ``.tsv`` and + ``.csv`` versions, presumably from different kilosort versions. + + This function was ported from Nick Steinmetz's ``spikes`` repository + MATLAB code, https://github.com/cortex-lab/spikes + + Parameters + ---------- + cluster_path + The full filepath to the ``cluster_groups`` tsv or csv file. + + Returns + ------- + cluster_ids + (num_clusters,) Array of (integer) unit IDs. + cluster_groups + (num_clusters,) Array of (integer) unit quality assignments, see + code below for mapping to "noise", "mua", "good" and "unsorted". 
+ """ + cluster_groups_table = pd.read_csv(cluster_path, sep="\t") + + group_key = cluster_groups_table.columns[1] # "groups" (csv) or "KSLabel" (tsv) + + for key, _id in zip( + ["noise", "mua", "good", "unsorted"], + ["0", "1", "2", "3"], + # required as str to avoid pandas replace downcast FutureWarning + ): + cluster_groups_table[group_key] = cluster_groups_table[group_key].replace( + key, _id + ) + + cluster_ids = cluster_groups_table["cluster_id"].to_numpy() + cluster_groups = cluster_groups_table[group_key].astype(int).to_numpy() + + return cluster_ids, cluster_groups + + +def get_noise_mask(spike_clusters: np.ndarray, sorter_output: Path) -> np.ndarray: + """Build a boolean mask identifying spikes that belong to noise clusters. + + Loads the cluster-groups file (``cluster_groups.csv`` or + ``cluster_group.tsv``) from the sorter output directory. Spikes + whose cluster is labelled *noise* (group == 0) are marked ``True``. + + Parameters + ---------- + spike_clusters + (num_spikes,) cluster assignment per spike. + sorter_output + Path to the Kilosort sorter output directory. + + Returns + ------- + np.ndarray + (num_spikes,) boolean array — ``True`` for spikes belonging to + a noise cluster. + + Raises + ------ + ValueError + If neither ``cluster_groups.csv`` nor ``cluster_group.tsv`` + exists in ``sorter_output``. 
+ """ + if not ( + (cluster_path := sorter_output / "cluster_groups.csv").is_file() + or (cluster_path := sorter_output / "cluster_group.tsv").is_file() + ): + raise ValueError( + f"`exclude_noise` is `True` but there is no `cluster_groups.csv/.tsv` " + f"in the sorting output at: {sorter_output}" + ) + + cluster_ids, cluster_groups = load_cluster_groups(cluster_path) + + noise_cluster_ids = cluster_ids[cluster_groups == 0] + + exclude_bool_mask = np.isin(spike_clusters.ravel(), noise_cluster_ids) + + return exclude_bool_mask + + +def get_ks_version(sorter_path: Path) -> str: + """ """ + log_file = list(sorter_path.glob("kilosort*.log")) + assert len(log_file) == 1 + + return Path(log_file[0]).name.split(".")[0] + + +def load_spike_clusters(sorter_path: Path) -> np.ndarray: + if (path_ := sorter_path / "spike_clusters.npy").is_file(): + spike_clusters = np.load(path_) + else: + spike_clusters = np.load(sorter_path / "spike_templates.npy") + return spike_clusters diff --git a/driftmap_viewer/helpers.py b/driftmap_viewer/helpers.py deleted file mode 100644 index 2b9ad2b..0000000 --- a/driftmap_viewer/helpers.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from pathlib import Path -import pandas as pd -import numpy as np - -def load_cluster_groups(cluster_path: Path) -> tuple[np.ndarray, ...]: - """ - Load kilosort `cluster_groups` file, that contains a table of - quality assignments, one per unit. These can be "noise", "mua", "good" - or "unsorted". - - There is some slight formatting differences between the `.tsv` and `.csv` - versions, presumably from different kilosort versions. - - This function was ported from Nick Steinmetz's `spikes` repository MATLAB code, - https://github.com/cortex-lab/spikes - - Parameters - ---------- - cluster_path : Path - The full filepath to the `cluster_groups` tsv or csv file. 
- - Returns - ------- - cluster_ids : np.ndarray - (num_clusters,) Array of (integer) unit IDs. - - cluster_groups : np.ndarray - (num_clusters,) Array of (integer) unit quality assignments, see code - below for mapping to "noise", "mua", "good" and "unsorted". - """ - cluster_groups_table = pd.read_csv(cluster_path, sep="\t") - - group_key = cluster_groups_table.columns[1] # "groups" (csv) or "KSLabel" (tsv) - - for key, _id in zip( - ["noise", "mua", "good", "unsorted"], - ["0", "1", "2", "3"], - # required as str to avoid pandas replace downcast FutureWarning - ): - cluster_groups_table[group_key] = cluster_groups_table[group_key].replace(key, _id) - - cluster_ids = cluster_groups_table["cluster_id"].to_numpy() - cluster_groups = cluster_groups_table[group_key].astype(int).to_numpy() - - return cluster_ids, cluster_groups - -# This is such a jankily written function fix it -def exclude_noise(sorter_output, spike_times, spike_amplitudes, spike_depths): - """""" - if (cluster_path := sorter_output / "spike_clusters.npy").is_file(): # TODO: this can be csv?!?!? - spike_clusters = np.load(cluster_path) - else: - # this is a pain to have here, I don't think this case is realistic. - raise NotImplementedError("spike clusters.csv does not exist. Under what circumstance is this? 
probably very old.") - # spike_clusters = spike_templates.copy() - - if ( # short circuit ensures cluster_path is assigned appropriately - (cluster_path := sorter_output / "cluster_groups.csv").is_file() - or (cluster_path := sorter_output / "cluster_group.tsv").is_file() - ): - cluster_ids, cluster_groups = load_cluster_groups(cluster_path) - - noise_cluster_ids = cluster_ids[cluster_groups == 0] - not_noise_clusters_by_spike = ~np.isin(spike_clusters.ravel(), - noise_cluster_ids) - spike_times = spike_times[not_noise_clusters_by_spike] - spike_amplitudes = spike_amplitudes[not_noise_clusters_by_spike] - spike_depths = spike_depths[not_noise_clusters_by_spike] - - return spike_times, spike_amplitudes, spike_depths - - raise ValueError( - f"`exclude_noise` is `True` but there is no `cluster_groups.csv` or `.tsv` " - f"in the sorting output at: {sorter_output}" - ) diff --git a/driftmap_viewer/interactive/driftmap_plot_widget.py b/driftmap_viewer/interactive/driftmap_plot_widget.py index 141b2b7..4d14f54 100644 --- a/driftmap_viewer/interactive/driftmap_plot_widget.py +++ b/driftmap_viewer/interactive/driftmap_plot_widget.py @@ -1,7 +1,12 @@ +from __future__ import annotations + +from typing import Any + import numpy as np import pyqtgraph as pg -from pyqtgraph.Qt import QtWidgets, QtCore -import matplotlib.pyplot as plt +from pyqtgraph.Qt import QtCore, QtWidgets + +from driftmap_viewer.data_model import DataModel pg.setConfigOption("background", "w") pg.setConfigOption("foreground", "k") @@ -9,24 +14,20 @@ class DriftmapPlotWidget(QtWidgets.QWidget): - """ - """ - def __init__(self, spike_times, spike_amplitudes, spike_depths, - spike_templates, templates, channel_positions, - amplitude_scaling="linear", n_color_bins=20, - point_size=5.0, sorter_path=None): + """ """ + + def __init__( + self, + processed_data: DataModel, + amplitude_scaling: str | tuple[float, float] = "linear", + n_color_bins: int = 20, + point_size: float = 5.0, + ) -> None: super().__init__() - 
print(f"Loaded {spike_times.size} spikes from {sorter_path}") + self.processed_data = processed_data - self.spike_times = spike_times - self.spike_amplitudes = spike_amplitudes - self.spike_depths = spike_depths - self.spike_templates = spike_templates - self.templates = templates - self.channel_positions = channel_positions - - self.cfgs = { + self.cfgs: dict[str, Any] = { "right_panel_view_mode": "heatmap", "left_panel_y_axis": { "on": False, @@ -36,15 +37,14 @@ def __init__(self, spike_times, spike_amplitudes, spike_depths, } self.selected_spot = None - self._trace_view_initialized = False self.resize(1400, 820) # Instantiate UI and scatter plot win_left, win_right = self._init_ui() - self._init_scatter_plot(win_left, spike_times, spike_amplitudes, - spike_depths, amplitude_scaling, - n_color_bins, point_size) + + self._init_scatter_plot(win_left, amplitude_scaling, n_color_bins, point_size) + self._init_panel_plot(win_right) # Connect widgets @@ -54,34 +54,40 @@ def __init__(self, spike_times, spike_amplitudes, spike_depths, self.scatter.sigClicked.connect(self.handle_click) self._view_radio_group.idToggled.connect(self.handle_view_radio_toggled) - def _init_scatter_plot(self, win_left, spike_times, spike_amplitudes, - spike_depths, amplitude_scaling, n_color_bins, - point_size): + def _init_scatter_plot( + self, + win_left: pg.GraphicsLayoutWidget, + amplitude_scaling: str | tuple[float, float], + n_color_bins: int, + point_size: float, + ) -> None: """Create the scatter plot on the left panel. Parameters ---------- - win_left : pg.GraphicsLayoutWidget + win_left The left graphics area to host the scatter plot. - spike_times, spike_amplitudes, spike_depths : np.ndarray - Spike data arrays. - amplitude_scaling : str | tuple + amplitude_scaling Colour-scaling mode or explicit (min, max) range. - n_color_bins : int + n_color_bins Number of grey-scale colour bins. - point_size : float + point_size Scatter-point diameter in pixels. 
""" + spike_times, spike_depths, spike_amplitudes = ( + self.processed_data.get_scatter_data() + ) + + # set amplitude colors + rgba_colors = self.processed_data.compute_amplitude_colors( + amplitude_scaling, n_color_bins + ) + self.p_scatter = win_left.addPlot(row=0, col=0) self.p_scatter.setLabel("bottom", "Time (s)") self.p_scatter.setLabel("left", "Depth (µm)") self.p_scatter.showGrid(x=False, y=False) - # set amplitude colors - rgba_colors = self._compute_amplitude_colors( - spike_amplitudes, amplitude_scaling, n_color_bins - ) - # set axis limits, pad around them slightly x_pad = (spike_times.max() - spike_times.min()) * 0.025 y_pad = (spike_depths.max() - spike_depths.min()) * 0.05 @@ -93,7 +99,8 @@ def _init_scatter_plot(self, win_left, spike_times, spike_amplitudes, yMax=spike_depths.max() + y_pad, ) - # create plot — each point stores its spike index in 'data' for click/tooltip lookup + # create plot — each point stores its spike index + # in 'data' for click/tooltip lookup self.scatter = pg.ScatterPlotItem( spike_times, spike_depths, @@ -105,18 +112,17 @@ def _init_scatter_plot(self, win_left, spike_times, spike_amplitudes, brush=rgba_colors, pen=None, tip=lambda x, y, data: ( - f"x={x:.3f}\ny={y:.1f}\n" - f"amp={self.spike_amplitudes[int(data)]:.2f}" + f"x={x:.3f}\ny={y:.1f}\namp={spike_amplitudes[int(data)]:.2f}" ), ) self.p_scatter.addItem(self.scatter) - def _init_panel_plot(self, win_right): + def _init_panel_plot(self, win_right: pg.GraphicsLayoutWidget) -> None: """Create the template panel plot on the right side. Parameters ---------- - win_right : pg.GraphicsLayoutWidget + win_right The right graphics area to host the panel plot. 
""" self.panel_plot = win_right.addPlot(row=0, col=0) @@ -124,74 +130,40 @@ def _init_panel_plot(self, win_right): self.panel_plot.setLabel("left", "amplitude") self.panel_plot.showGrid(x=False, y=False) - def _connect_signals(self): - """Wire up Qt signal/slot connections.""" - self.ymin_spin.valueChanged.connect(self.handle_y_spinbox_min) - self.ymax_spin.valueChanged.connect(self.handle_y_spinbox_max) - self._fix_limits_cb.toggled.connect(self.handle_fix_ylim_cb) - self.scatter.sigClicked.connect(self.handle_click) - self._view_radio_group.idToggled.connect(self.handle_view_radio_toggled) - - def handle_view_radio_toggled(self, button_id, checked): + def handle_view_radio_toggled(self, button_id: int, checked: bool) -> None: if not checked: return - mode_map = {0: "max_waveform", 1: "heatmap", 2: "heatmap_all_channels", 3: "trace_view"} + mode_map = {0: "heatmap", 1: "heatmap_all_channels"} mode = mode_map[button_id] self.cfgs["right_panel_view_mode"] = mode - self._trace_view_initialized = False - if mode == "trace_view": - self._limits_page.setVisible(False) - else: - self._limits_page.setVisible(True) - if mode == "max_waveform": - self._fix_limits_cb.setText("Fix y-limits") - self._min_label.setText("Y min:") - self._max_label.setText("Y max:") - else: - self._fix_limits_cb.setText("Fix color limits") - self._min_label.setText("C min:") - self._max_label.setText("C max:") + self._fix_limits_cb.setText("Fix color limits") + self._min_label.setText("C min:") + self._max_label.setText("C max:") if self.selected_spot is not None: - self.set_y_limit() self.update_panel(int(self.selected_spot.data())) - def handle_y_spinbox_min(self, value): + def handle_y_spinbox_min(self, value: float) -> None: self.cfgs["left_panel_y_axis"]["y_min"] = value if self.selected_spot is not None: - self.set_y_limit() self.update_panel(int(self.selected_spot.data())) - def handle_y_spinbox_max(self, value): + def handle_y_spinbox_max(self, value: float) -> None: 
self.cfgs["left_panel_y_axis"]["y_max"] = value if self.selected_spot is not None: - self.set_y_limit() self.update_panel(int(self.selected_spot.data())) - def handle_fix_ylim_cb(self, active): + def handle_fix_ylim_cb(self, active: bool) -> None: self.ymin_spin.setEnabled(active) self.ymax_spin.setEnabled(active) self.cfgs["left_panel_y_axis"]["on"] = active - self.set_y_limit() if self.selected_spot is None: return self.update_panel(int(self.selected_spot.data())) - def set_y_limit(self): - mode = self.cfgs["right_panel_view_mode"] - if mode == "max_waveform": - if self.cfgs["left_panel_y_axis"]["on"]: - self.panel_plot.setYRange(self.ymin_spin.value(), self.ymax_spin.value(), padding=0) - else: - self.panel_plot.enableAutoRange(axis='y') - elif mode in ("heatmap", "heatmap_all_channels"): - pass # color limits applied during draw - else: - self.panel_plot.enableAutoRange() - - def handle_click(self, _, points, __): + def handle_click(self, _, points, __) -> None: if points is None or len(points) <= 0: return @@ -200,42 +172,29 @@ def handle_click(self, _, points, __): if self.selected_spot is not None: self.selected_spot.setPen(pg.mkPen(None)) - spot.setPen(pg.mkPen('r', width=2)) + spot.setPen(pg.mkPen("r", width=2)) self.selected_spot = spot idx = int(spot.data()) self.update_panel(idx) - def update_panel(self, spike_idx): - template_id = int(self.spike_templates[spike_idx]) - self.panel_plot.setTitle(f"Template {template_id}") - - if self.cfgs["right_panel_view_mode"] == "max_waveform": - self._draw_max_waveform_on_panel(spike_idx) - elif self.cfgs["right_panel_view_mode"] == "trace_view": - self._draw_template_trace_view_on_panel(spike_idx) - else: - self._draw_template_heatmap_on_panel(spike_idx) - - # TODO: carefully check these!!! 
- - def _draw_max_waveform_on_panel(self, spike_index): - template_waveform = self.get_max_waveform_data(spike_index) - n_samples = template_waveform.size + def update_panel(self, spike_idx: int) -> None: + self.panel_plot.setTitle( + f"Template {self.processed_data.get_template_id(spike_idx)}" + ) - pen = pg.mkPen("k", width=2.5) - self.panel_plot.clear() - self.panel_plot.plot(np.arange(n_samples), template_waveform, pen=pen) + self._draw_template_heatmap_on_panel(spike_idx) - self.panel_plot.setLabel("bottom", "sample") - self.panel_plot.setLabel("left", "amplitude") - self.panel_plot.getAxis("left").setTicks(None) - self.panel_plot.getAxis("left").setStyle(showValues=True) - self.panel_plot.setXRange(0, n_samples, padding=0.05) + def _draw_template_heatmap_on_panel(self, spike_index: int) -> None: + """ """ + template_waveform_2d = self.processed_data.get_template_heatmap( + spike_index, self.cfgs["right_panel_view_mode"] + ) - def _draw_template_heatmap_on_panel(self, spike_index): - template_waveform_2d = self.get_heatmap_data(spike_index) - n_samples, n_chans = template_waveform_2d.shape[0], template_waveform_2d.shape[1] + n_samples, n_chans = ( + template_waveform_2d.shape[0], + template_waveform_2d.shape[1], + ) self.panel_plot.clear() @@ -261,232 +220,36 @@ def _draw_template_heatmap_on_panel(self, spike_index): image_item.setColorMap(cmap) if self.cfgs["left_panel_y_axis"]["on"]: - image_item.setLevels(( - self.ymin_spin.value(), - self.ymax_spin.value(), - )) + image_item.setLevels( + ( + self.ymin_spin.value(), + self.ymax_spin.value(), + ) + ) image_item.setImage(template_waveform_2d) image_item.setRect(0, 0, n_samples, n_chans) self.panel_plot.setXRange(0, n_samples, padding=0.05) - self.panel_plot.setYRange(0, n_chans, padding=0.05) - - def _draw_template_trace_view_on_panel(self, spike_index): - template_idx = self.spike_templates[spike_index] - wv = self.templates[template_idx].copy() * self.spike_amplitudes[spike_index] - n_samples, n_chan = 
wv.shape - - # Use only channels with non-zero data, unwrapping if KS wrapped them - contains_data_idx, is_wrapped = self._get_nonzero_channel_indices(wv) - wv = wv[:, contains_data_idx] - xc = self.channel_positions[contains_data_idx, 0] - yc = self.channel_positions[contains_data_idx, 1] - - if is_wrapped: - xc, yc = self._make_positions_contiguous(xc, yc) - - chan_spacing, x_spacing = self._compute_trace_spacing(xc, yc) - amp_scale = (chan_spacing * 0.45) / np.max(np.abs(wv)) if np.max(np.abs(wv)) > 0 else 1.0 - - self._plot_traces(wv, xc, yc, n_samples, x_spacing, amp_scale) - self._set_trace_view_range(xc, yc, chan_spacing, x_spacing) - - @staticmethod - def _compute_trace_spacing(xc, yc): - """Compute channel and x spacing from probe positions. + self.panel_plot.setYRange(0, n_chans, padding=0.0) - Returns - ------- - chan_spacing : float - Minimum y-distance between adjacent channels. - x_spacing : float - Minimum x-distance between adjacent channel columns. - """ - unique_y = np.unique(yc) - chan_spacing = np.min(np.diff(np.sort(unique_y))) if len(unique_y) > 1 else 1.0 - - unique_x = np.unique(xc) - x_spacing = np.min(np.diff(np.sort(unique_x))) if len(unique_x) > 1 else 20.0 - - return chan_spacing, x_spacing - - def _plot_traces(self, wv, xc, yc, n_samples, x_spacing, amp_scale): - """Draw one waveform trace per channel on the panel plot.""" - self.panel_plot.clear() - - # Time axis in physical units, scaled to fit within ~90% of x_spacing - t = np.arange(-n_samples // 2, n_samples // 2, 1, dtype=np.float32) - t *= (x_spacing * 0.9) / n_samples - - for ii, (xi, yi) in enumerate(zip(xc, yc)): - self.panel_plot.plot( - xi + t, yi + wv[:, ii] * amp_scale, - pen=pg.mkPen('k', width=0.5), - ) - - self.panel_plot.setLabel("bottom", "x position") - self.panel_plot.setLabel("left", "y position (\u00b5m)") - self.panel_plot.getAxis("left").setTicks(None) - self.panel_plot.getAxis("left").setStyle(showValues=True) - - def _set_trace_view_range(self, xc, yc, 
chan_spacing, x_spacing): - """Set the view range for the trace view, centering on first click - and preserving zoom on subsequent clicks.""" - y_center = (yc.min() + yc.max()) / 2 - x_center = (xc.min() + xc.max()) / 2 - - if not self._trace_view_initialized: - y_pad = (yc.max() - yc.min()) * 0.15 + chan_spacing - x_pad = (xc.max() - xc.min()) * 0.15 + x_spacing - self.panel_plot.setXRange(xc.min() - x_pad, xc.max() + x_pad, padding=0) - self.panel_plot.setYRange(yc.min() - y_pad, yc.max() + y_pad, padding=0) - self._trace_view_initialized = True - else: - view_box = self.panel_plot.getViewBox() - [[x_lo, x_hi], [y_lo, y_hi]] = view_box.viewRange() - y_half = (y_hi - y_lo) / 2 - x_half = (x_hi - x_lo) / 2 - self.panel_plot.setXRange(x_center - x_half, x_center + x_half, padding=0) - self.panel_plot.setYRange(y_center - y_half, y_center + y_half, padding=0) - - def get_max_waveform_data(self, spike_index): - template_idx = self.spike_templates[spike_index] - scaled_template = self.templates[template_idx, :, :] * self.spike_amplitudes[spike_index] - peak_ch = np.argmax(np.max(np.abs(scaled_template), axis=0)) - return scaled_template[:, peak_ch] - - @staticmethod - def _compute_amplitude_colors(spike_amplitudes, amplitude_scaling, n_color_bins): - """Map spike amplitudes to RGBA colours via grey-scale binning. - - Parameters - ---------- - spike_amplitudes : np.ndarray - (num_spikes,) raw amplitude values. - amplitude_scaling : {"linear", "log2", "log10"} | tuple - Scaling mode. A 2-tuple ``(min, max)`` fixes the colour - range explicitly. - n_color_bins : int - Number of grey-scale bins. - - Returns - ------- - np.ndarray - (num_spikes, 4) uint8 RGBA values in [0, 255]. 
- """ - amp_values = spike_amplitudes.copy() - - if isinstance(amplitude_scaling, tuple): - amp_min, amp_max = amplitude_scaling - elif amplitude_scaling == "log2": - amp_values = np.log2(amp_values) - amp_min, amp_max = amp_values.min(), amp_values.max() - elif amplitude_scaling == "log10": - amp_values = np.log10(amp_values) - amp_min, amp_max = amp_values.min(), amp_values.max() - else: # "linear" - amp_min, amp_max = amp_values.min(), amp_values.max() - - color_bins = np.linspace(amp_min, amp_max, n_color_bins) - gray_colors = plt.get_cmap("gray")(np.linspace(0, 1, n_color_bins))[::-1] - bin_indices = np.clip( - np.searchsorted(color_bins, amp_values, side="right") - 1, - 0, n_color_bins - 2, - ) - return (gray_colors[bin_indices] * 255).astype(np.uint8) - - @staticmethod - def _get_nonzero_channel_indices(scaled_template): - """Get indices of channels with data, unwrapping if KS wrapped them. - - Kilosort can wrap template channel assignments around the probe - boundaries (e.g. channels [0, 1, 2, 380, 381, 382]). This detects - the wrap and reorders so the high-index group comes first, - giving a spatially contiguous result. - - Returns - ------- - contains_data_idx : np.ndarray - Channel indices with data, reordered if wrapping detected. - is_wrapped : bool - True if wrapping was detected and corrected. - """ - # KS templates have zeros on inactive channels; checking row 0 is - # sufficient because active channels always have non-zero values at - # the first sample (the template extends across all samples). 
- contains_data_idx = np.where(scaled_template[0, :] != 0)[0] - - if len(contains_data_idx) < 2: - return contains_data_idx, False - - # Check for a gap (non-contiguous indices) - diffs = np.diff(contains_data_idx) - gap_positions = np.where(diffs > 1)[0] - - if len(gap_positions) == 1: - # Wrapped: split at gap, put the higher-index group first - split = gap_positions[0] + 1 - contains_data_idx = np.concatenate([ - contains_data_idx[split:], - contains_data_idx[:split], - ]) - return contains_data_idx, True - - return contains_data_idx, False - - @staticmethod - def _make_positions_contiguous(xc, yc): - """Remap channel positions so wrapped channels sit contiguously. - - When KS wraps, channels from the top and bottom of the probe - end up in the same template. We shift the lower-position group - to sit just above the higher-position group (or vice versa) - so they display as one contiguous block. - """ - sorted_y = np.sort(np.unique(yc)) - if len(sorted_y) < 2: - return xc, yc.copy() - - y_diffs = np.diff(sorted_y) - largest_gap_idx = np.argmax(y_diffs) - gap_size = y_diffs[largest_gap_idx] - typical_spacing = np.median(y_diffs) - - if gap_size > typical_spacing * 3: - gap_threshold = sorted_y[largest_gap_idx] + gap_size / 2 - yc_new = yc.copy() - upper_mask = yc >= gap_threshold - shift = yc[upper_mask].min() - yc[~upper_mask].max() - typical_spacing - yc_new[upper_mask] -= shift - return xc, yc_new - - return xc, yc.copy() - - def get_heatmap_data(self, spike_index): - """""" - template_idx = self.spike_templates[spike_index] - scaled_template = self.templates[template_idx, :, :] * self.spike_amplitudes[spike_index] - - if self.cfgs["right_panel_view_mode"] == "heatmap_all_channels": - scaled_template = scaled_template.copy() # TODO: check if this is necessary - scaled_template[:, scaled_template[0, :] == 0] = np.nan - else: - contains_data_idx, _ = self._get_nonzero_channel_indices(scaled_template) - scaled_template = scaled_template[:, contains_data_idx] - - 
return scaled_template - - - # UI - To possibly be moved to QtDesigner + # UI # ---------------------------------------------------------------------------------- - def _init_ui(self): + def _connect_signals(self) -> None: + """Wire up Qt signal/slot connections.""" + self.ymin_spin.valueChanged.connect(self.handle_y_spinbox_min) + self.ymax_spin.valueChanged.connect(self.handle_y_spinbox_max) + self._fix_limits_cb.toggled.connect(self.handle_fix_ylim_cb) + self.scatter.sigClicked.connect(self.handle_click) + self._view_radio_group.idToggled.connect(self.handle_view_radio_toggled) + + def _init_ui(self) -> tuple[pg.GraphicsLayoutWidget, pg.GraphicsLayoutWidget]: """Build the widget layout: splitter, controls bar, radio buttons, spinboxes. Returns ------- - win_left : pg.GraphicsLayoutWidget + win_left Left graphics area (for the scatter plot). - win_right : pg.GraphicsLayoutWidget + win_right Right graphics area (for the template panel). """ # Core layout @@ -513,7 +276,7 @@ def _init_ui(self): splitter.setSizes([int(total * 0.75), int(total * 0.25)]) # Controls Bar — below the full splitter - # -------------------------------------------------------------------------------------------------------------- + # ------------------------------------------------------- controls_widget = QtWidgets.QWidget() controls_layout = QtWidgets.QVBoxLayout(controls_widget) controls_layout.setContentsMargins(6, 6, 6, 6) @@ -525,22 +288,18 @@ def _init_ui(self): radio_layout.setContentsMargins(0, 0, 0, 0) radio_layout.setSpacing(12) - self.radio_max_wf = QtWidgets.QRadioButton("Max waveform") - self.radio_heatmap = QtWidgets.QRadioButton("Heatmap") - self.radio_heatmap_all = QtWidgets.QRadioButton("Heatmap (all channels)") - self.radio_trace_view = QtWidgets.QRadioButton("Trace view") + self.radio_heatmap = QtWidgets.QRadioButton("Template heatmap") + self.radio_heatmap_all = QtWidgets.QRadioButton( + "Template heatmap (all channels)" + ) self.radio_heatmap.setChecked(True) 
self._view_radio_group = QtWidgets.QButtonGroup(self) - self._view_radio_group.addButton(self.radio_max_wf, 0) - self._view_radio_group.addButton(self.radio_heatmap, 1) - self._view_radio_group.addButton(self.radio_heatmap_all, 2) - self._view_radio_group.addButton(self.radio_trace_view, 3) + self._view_radio_group.addButton(self.radio_heatmap, 0) + self._view_radio_group.addButton(self.radio_heatmap_all, 1) - radio_layout.addWidget(self.radio_max_wf) radio_layout.addWidget(self.radio_heatmap) radio_layout.addWidget(self.radio_heatmap_all) - radio_layout.addWidget(self.radio_trace_view) radio_layout.addStretch() controls_layout.addWidget(radio_row) @@ -550,7 +309,7 @@ def _init_ui(self): limits_layout.setContentsMargins(0, 0, 0, 0) limits_layout.setSpacing(8) - self._fix_limits_cb = QtWidgets.QCheckBox("Fix y-limits") + self._fix_limits_cb = QtWidgets.QCheckBox("Fix color limits") self.ymin_spin = QtWidgets.QDoubleSpinBox() self.ymax_spin = QtWidgets.QDoubleSpinBox() for spin in (self.ymin_spin, self.ymax_spin): @@ -564,8 +323,8 @@ def _init_ui(self): self.ymax_spin.setValue(self.cfgs["left_panel_y_axis"]["y_max"]) limits_layout.addWidget(self._fix_limits_cb) - self._min_label = QtWidgets.QLabel("Y min:") - self._max_label = QtWidgets.QLabel("Y max:") + self._min_label = QtWidgets.QLabel("C min:") + self._max_label = QtWidgets.QLabel("C max:") limits_layout.addWidget(self._min_label) limits_layout.addWidget(self.ymin_spin) limits_layout.addWidget(self._max_label) diff --git a/driftmap_viewer/interactive/driftmap_view.py b/driftmap_viewer/interactive/driftmap_view.py deleted file mode 100644 index af36959..0000000 --- a/driftmap_viewer/interactive/driftmap_view.py +++ /dev/null @@ -1,277 +0,0 @@ -from mpl_plotting.driftmapviewer_new import get_drift_map_plot, _plot_kilosort_drift_map_raster -import matplotlib.pyplot as plt -from ks_extractors import kilosort1_3 -from ks_extractors import kilosort_4 -from ks_extractors import helpers -from pathlib import Path -import 
numpy as np -from .driftmap_plot_widget import DriftmapPlotWidget -import numpy as np -import pyqtgraph as pg -from pyqtgraph.Qt import QtWidgets, QtCore -import matplotlib.pyplot as plt - - -class DriftMapView(): - """Load Kilosort sorter output and provide interactive or static drift map plots. - - On construction, spike data is loaded from a Kilosort output directory - and stored as read-only arrays. Plotting methods apply optional - filtering (noise exclusion, amplitude filtering, decimation) before - handing the data to a plot backend. - - Parameters - ---------- - sorter_path : str | Path - Path to a Kilosort sorter output directory. Must contain - exactly one ``kilosort*.log`` file used to detect the KS version. - - Attributes - ---------- - spike_times : np.ndarray - (num_spikes,) spike times (seconds for KS 1-3, samples for KS4). - spike_amplitudes : np.ndarray - (num_spikes,) spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) spike depths along the probe (µm). - spike_templates : np.ndarray - (num_spikes,) template index assigned to each spike. - templates : np.ndarray - (num_templates, num_samples, num_channels) template waveforms. - channel_positions : np.ndarray - (num_channels, 2) x/y positions of each channel on the probe. - - TODO - ---- - - Evaluate memory cost of holding all arrays; consider lazy / mmap loading. - - Harmonise spike_times units (seconds everywhere). - """ - def __init__(self, sorter_path): - """Load spike data from a Kilosort output directory. - - Parameters - ---------- - sorter_path : str | Path - Path to the Kilosort sorter output. - - Raises - ------ - AssertionError - If the directory does not contain exactly one ``kilosort*.log`` - file, or if the loaded spike arrays have mismatched sizes. 
- """ - self.sorter_path = Path(sorter_path) - - log_file = list(self.sorter_path.glob("kilosort*.log")) - assert len(log_file) == 1 - self.ks_version = Path(log_file[0]).name.split(".")[0] - - func = kilosort_4.get_spikes_info_ks4 if self.ks_version == "kilosort4" else kilosort1_3.get_spikes_info_ks1_3 - - ( - self.spike_times, - self.spike_amplitudes, - self.spike_depths, - self.spike_templates, - self.templates, - self.channel_positions - ) = func( - self.sorter_path - ) - - assert self.spike_times.size == self.spike_amplitudes.size == self.spike_depths.size == self.spike_templates.size - - self.spike_times.flags.writeable = False - self.spike_amplitudes.flags.writeable = False - self.spike_depths.flags.writeable = False - self.spike_templates.flags.writeable = False - self.templates.flags.writeable = False - self.channel_positions.flags.writeable = False - - def _process_data( - self, - exclude_noise, - decimate, - filter_amplitude_mode, - filter_amplitude_values - ): - """Filter and subsample the loaded spike data. - - Operations are applied in order: decimation → noise exclusion → - amplitude filtering → masking. Decimation is applied first as a - performance knob to thin the full dataset before further filtering. - - Parameters - ---------- - exclude_noise : bool - If ``True``, spikes belonging to clusters labelled "noise" in - the Kilosort cluster groups file are removed. - decimate : int | False - Keep every *n*-th spike. Applied first to reduce the dataset - before noise/amplitude filters. ``False`` disables decimation. - filter_amplitude_mode : {"percentile", "absolute"} | None - How ``filter_amplitude_values`` is interpreted. - ``None`` disables amplitude filtering. - filter_amplitude_values : tuple of float - (low, high) bounds. Interpreted as percentile ranks or - absolute amplitude values depending on ``filter_amplitude_mode``. 
- - Returns - ------- - spike_times : np.ndarray - spike_amplitudes : np.ndarray - spike_depths : np.ndarray - spike_templates : np.ndarray - Filtered copies (views when no filtering is needed) of the - corresponding instance arrays. - """ - # Select a view for now, this may be copied depending on options (e.g. decimate) - spike_times = self.spike_times - spike_amplitudes = self.spike_amplitudes - spike_depths = self.spike_depths - spike_templates = self.spike_templates - - keep_bool_mask = None - - if exclude_noise: - keep_bool_mask = ~helpers.get_noise_mask( - self.sorter_path - ) - - if filter_amplitude_mode is not None: - assert filter_amplitude_mode in ["percentile", "absolute"] - - if filter_amplitude_mode == "percentile": - min_val, max_val = np.percentile( - spike_amplitudes, filter_amplitude_values - ) - else: - min_val, max_val = filter_amplitude_values - - if keep_bool_mask is None: - keep_bool_mask = np.ones(spike_amplitudes.size, dtype=bool) - - keep_bool_mask[spike_amplitudes < min_val] = False - keep_bool_mask[spike_amplitudes > max_val] = False - - if keep_bool_mask is not None: - spike_times = spike_times[keep_bool_mask] - spike_amplitudes = spike_amplitudes[keep_bool_mask] - spike_depths = spike_depths[keep_bool_mask] - spike_templates = spike_templates[keep_bool_mask] - - if decimate: - spike_times = spike_times[:: decimate] - spike_amplitudes = spike_amplitudes[:: decimate] - spike_depths = spike_depths[:: decimate] - spike_templates = spike_templates[:: decimate] - - return spike_times, spike_amplitudes, spike_depths, spike_templates - - def drift_map_plot_interactive( - self, - decimate=False, - exclude_noise=True, - amplitude_scaling="linear", - n_color_bins=20, - point_size=7.5, - filter_amplitude_mode=None, - filter_amplitude_values=(), - ): - """Create an interactive pyqtgraph-based drift map widget. - - Parameters - ---------- - decimate : int | False - Keep every *n*-th spike. ``False`` disables decimation. 
- exclude_noise : bool - Remove spikes labelled as noise. - amplitude_scaling : {"linear", "log2", "log10"} | tuple - Colour-scaling mode. A 2-tuple ``(min, max)`` fixes the - colour range explicitly. - n_color_bins : int - Number of grey-scale colour bins for amplitude. - point_size : float - Scatter-point diameter in pixels. - filter_amplitude_mode : {"percentile", "absolute"} | None - Amplitude filtering mode (see ``_process_data``). - filter_amplitude_values : tuple of float - Bounds for amplitude filtering. - - Returns - ------- - DriftmapPlotWidget - The pyqtgraph widget (already populated but not yet shown). - """ - ( - spike_times, - spike_amplitudes, - spike_depths, - spike_templates, - ) = self._process_data( - exclude_noise, - decimate, - filter_amplitude_mode, - filter_amplitude_values - ) - - self.plot = DriftmapPlotWidget( - spike_times, - spike_amplitudes, - spike_depths, - spike_templates, - self.templates, - self.channel_positions, - amplitude_scaling=amplitude_scaling, - n_color_bins=n_color_bins, - point_size=point_size, - sorter_path=self.sorter_path - ) - - return self.plot - - # ---------------------------------------------------------------------------------- - # TODO MATPLOTLIB - # ---------------------------------------------------------------------------------- - - def _drift_map_plot_matplotlib(self, - decimate=False, - exclude_noise=True, - log_transform_amplitudes=True, - filter_amplitude_mode=None, - filter_amplitude_values=(), - ): - ( - spike_times, - spike_amplitudes, - spike_depths, - spike_templates, - ) = self._process_data( - exclude_noise, - decimate, - filter_amplitude_mode, - filter_amplitude_values - ) - - fig = plt.figure(figsize=(10, 10 * (6 / 8))) - raster_axis = fig.add_subplot() - - _plot_kilosort_drift_map_raster( - spike_times, - spike_amplitudes, - spike_depths, - axis=raster_axis, - ) - - # histogram - if False: - hist_axis.set_xlabel("count") - raster_axis.set_xlabel("time") - hist_axis.set_ylabel("y position") - - 
bin_centers, counts = _compute_activity_histogram( - spike_amplitudes, spike_depths, weight_histogram_by_amplitude - ) - hist_axis.plot(counts, bin_centers, color="black", linewidth=1) - - return fig diff --git a/driftmap_viewer/interactive/multi_session_drift_map.py b/driftmap_viewer/interactive/multi_session_drift_map.py index 295f700..c789610 100644 --- a/driftmap_viewer/interactive/multi_session_drift_map.py +++ b/driftmap_viewer/interactive/multi_session_drift_map.py @@ -1,5 +1,7 @@ import math + from PySide6 import QtWidgets + from .driftmap_plot_widget import DriftmapPlotWidget @@ -12,14 +14,14 @@ class MultiSessionDriftmapWidget(QtWidgets.QWidget): Parameters ---------- - panels : list[DriftmapPlotWidget] + panels Drift-map widgets to arrange in the grid. - grid : tuple[int, int] | None + grid Explicit ``(n_rows, n_cols)`` layout. If ``None``, a roughly square layout is computed automatically. - width : int + width Width allocated per panel column (pixels). - height : int + height Height allocated per panel row (pixels). """ @@ -51,22 +53,22 @@ def _compute_grid_dimensions( Parameters ---------- - num_panels : int + num_panels Total number of panels to arrange. - grid : tuple[int, int] | None + grid User-specified ``(n_rows, n_cols)``. If ``None``, a roughly square grid is computed automatically. Returns ------- - n_rows, n_cols : int + tuple of int + (n_rows, n_cols). """ if grid is not None: n_rows, n_cols = grid if n_rows * n_cols != num_panels: raise ValueError( - f"grid {grid} expects {n_rows * n_cols} panels " - f"but got {num_panels}" + f"grid {grid} expects {n_rows * n_cols} panels but got {num_panels}" ) return n_rows, n_cols @@ -85,9 +87,9 @@ def _populate_grid( Parameters ---------- - panels : list[DriftmapPlotWidget] + panels Widgets to add. - n_rows, n_cols : int + n_rows, n_cols Grid dimensions. 
""" grid_layout = QtWidgets.QGridLayout(self) @@ -112,9 +114,9 @@ def _link_y_axes(panels: list[DriftmapPlotWidget]) -> None: Parameters ---------- - panels : list[DriftmapPlotWidget] + panels Must contain at least one panel. """ ref = panels[0].p_scatter for panel in panels[1:]: - panel.p_scatter.setYLink(ref) \ No newline at end of file + panel.p_scatter.setYLink(ref) diff --git a/driftmap_viewer/kilosort_4.py b/driftmap_viewer/kilosort_4.py deleted file mode 100644 index f526fd9..0000000 --- a/driftmap_viewer/kilosort_4.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pathlib import Path - -import numpy as np -import helpers - - -def get_spikes_info_ks4( - sorter_output: Path, -): # -> tuple[np.ndarray, ...]: - - spike_times = np.load(sorter_output / "spike_times.npy") - spike_amplitudes = np.load(sorter_output / "amplitudes.npy") - spike_depths = np.load(sorter_output / "spike_positions.npy")[:, 1] - - return spike_times, spike_amplitudes, spike_depths diff --git a/driftmap_viewer/ks_extractors/__init__.py b/driftmap_viewer/ks_extractors/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/driftmap_viewer/ks_extractors/helpers.py b/driftmap_viewer/ks_extractors/helpers.py deleted file mode 100644 index 36cb8ee..0000000 --- a/driftmap_viewer/ks_extractors/helpers.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING -from pathlib import Path - -import pandas as pd -import numpy as np - - -def load_cluster_groups(cluster_path: Path) -> tuple[np.ndarray, ...]: - """ - Load kilosort `cluster_groups` file, that contains a table of - quality assignments, one per unit. These can be "noise", "mua", "good" - or "unsorted". - - There is some slight formatting differences between the `.tsv` and `.csv` - versions, presumably from different kilosort versions. 
- - This function was ported from Nick Steinmetz's `spikes` repository MATLAB code, - https://github.com/cortex-lab/spikes - - Parameters - ---------- - cluster_path : Path - The full filepath to the `cluster_groups` tsv or csv file. - - Returns - ------- - cluster_ids : np.ndarray - (num_clusters,) Array of (integer) unit IDs. - - cluster_groups : np.ndarray - (num_clusters,) Array of (integer) unit quality assignments, see code - below for mapping to "noise", "mua", "good" and "unsorted". - """ - cluster_groups_table = pd.read_csv(cluster_path, sep="\t") - - group_key = cluster_groups_table.columns[1] # "groups" (csv) or "KSLabel" (tsv) - - for key, _id in zip( - ["noise", "mua", "good", "unsorted"], - ["0", "1", "2", "3"], - # required as str to avoid pandas replace downcast FutureWarning - ): - cluster_groups_table[group_key] = cluster_groups_table[group_key].replace(key, _id) - - cluster_ids = cluster_groups_table["cluster_id"].to_numpy() - cluster_groups = cluster_groups_table[group_key].astype(int).to_numpy() - - return cluster_ids, cluster_groups - - -def get_noise_mask(sorter_output: Path) -> np.ndarray: - """Build a boolean mask identifying spikes that belong to noise clusters. - - Loads ``spike_clusters.npy`` and the cluster-groups file - (``cluster_groups.csv`` or ``cluster_group.tsv``) from the sorter - output directory. Spikes whose cluster is labelled *noise* - (group == 0) are marked ``True``. - - Parameters - ---------- - sorter_output : Path - Path to the Kilosort sorter output directory. - - Returns - ------- - np.ndarray - (num_spikes,) boolean array — ``True`` for spikes belonging to - a noise cluster. - - Raises - ------ - FileNotFoundError - If ``spike_clusters.npy`` is not found. - ValueError - If neither ``cluster_groups.csv`` nor ``cluster_group.tsv`` - exists in ``sorter_output``. 
- """ - if (cluster_path := sorter_output / "spike_clusters.npy").is_file(): - spike_clusters = np.load(cluster_path) - else: - raise FileNotFoundError("spike_clusters.npy does not exist.") - - if not ( - (cluster_path := sorter_output / "cluster_groups.csv").is_file() - or (cluster_path := sorter_output / "cluster_group.tsv").is_file() - ): - raise ValueError( - f"`exclude_noise` is `True` but there is no `cluster_groups.csv/.tsv` " - f"in the sorting output at: {sorter_output}" - ) - - cluster_ids, cluster_groups = load_cluster_groups(cluster_path) - - noise_cluster_ids = cluster_ids[cluster_groups == 0] - - exclude_bool_mask = np.isin(spike_clusters.ravel(), noise_cluster_ids) - - return exclude_bool_mask - - -def get_pooled_amplitudes(paths: list[Path]) -> np.ndarray: - """Load and concatenate amplitudes.npy from multiple sorter output paths. - - Parameters - ---------- - paths : list of Path - List of sorter output directories, each containing amplitudes.npy. - - Returns - ------- - np.ndarray - Concatenated amplitudes from all paths. - """ - return np.concatenate([np.load(Path(p) / "amplitudes.npy").ravel() for p in paths]) diff --git a/driftmap_viewer/ks_extractors/kilosort1_3.py b/driftmap_viewer/ks_extractors/kilosort1_3.py deleted file mode 100644 index e66e988..0000000 --- a/driftmap_viewer/ks_extractors/kilosort1_3.py +++ /dev/null @@ -1,243 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from pathlib import Path - import numpy as np -import numpy as np -from spikeinterface.core import read_python - - -def get_spikes_info_ks1_3( - sorter_output: str | Path, -) -> tuple[np.ndarray, ...]: - """ - Compute the amplitude and depth of all detected spikes from the kilosort output. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - sorter_output : str | Path - Path to the kilosort run sorting output. 
- - Returns - ------- - spike_times : np.ndarray - (num_spikes,) array of spike times. - spike_amplitudes : np.ndarray - (num_spikes,) array of corresponding spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of corresponding depths (probe y-axis location). - - Notes - ----- - In `_template_positions_amplitudes` spike depths is calculated as simply the template - depth, for each spike (so it is the same for all spikes in a cluster). Here we need - to find the depth of each individual spike, using its low-dimensional projection. - `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. - Taking the first component, the subset of 32 channels associated with this - spike are indexed to get the actual channel locations (in um). Then, the channel - locations are weighted by their PC values. - """ - if isinstance(sorter_output, str): - sorter_output = Path(sorter_output) - - params = _load_ks_dir(sorter_output, load_pcs=True) - - # Compute spike depths - pc_features = params["pc_features"][:, 0, :] - pc_features[pc_features < 0] = 0 - - # Get the channel indexes corresponding to the 32 channels from the PC. - spike_features_indices = params["pc_features_indices"][params["spike_templates"], :] - - ycoords = params["channel_positions"][:, 1] - spike_feature_ycoords = ycoords[spike_features_indices] - - spike_depths = np.sum(spike_feature_ycoords * pc_features ** 2, axis=1) / np.sum( - pc_features ** 2, axis=1) - - # Compute amplitudes, scale if required and drop un-localised spikes before returning. 
- spike_amplitudes, _, _, _, unwhite_templates, *_ = _template_positions_amplitudes( - params["templates"], - params["whitening_matrix_inv"], - ycoords, - params["spike_templates"], - params["temp_scaling_amplitudes"], - ) - - return params["spike_times"], spike_amplitudes, spike_depths, params["spike_templates"].squeeze(), unwhite_templates, params["channel_positions"] - - -def _template_positions_amplitudes( - templates: np.ndarray, - inverse_whitening_matrix: np.ndarray, - ycoords: np.ndarray, - spike_templates: np.ndarray, - template_scaling_amplitudes: np.ndarray, -) -> tuple[np.ndarray, ...]: - """ - Calculate the amplitude and depths of (unwhitened) templates and spikes. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - templates : np.ndarray - (num_clusters, num_samples, num_channels) array of templates. - inverse_whitening_matrix: np.ndarray - Inverse of the whitening matrix used in KS preprocessing, used to - unwhiten templates. - ycoords : np.ndarray - (num_channels,) array of the y-axis (depth) channel positions. - spike_templates : np.ndarray - (num_spikes,) array indicating the template associated with each spike. - template_scaling_amplitudes : np.ndarray - (num_spikes,) array holding the scaling amplitudes, by which the - template was scaled to match each spike. - - Returns - ------- - spike_amplitudes : np.ndarray - (num_spikes,) array of the amplitude of each spike. - spike_depths : np.ndarray - (num_spikes,) array of the depth (probe y-axis) of each spike. Note - this is just the template depth for each spike (i.e. depth of all spikes - from the same cluster are identical). - template_amplitudes : np.ndarray - (num_templates,) Amplitude of each template, calculated as average of spike amplitudes. - template_depths : np.ndarray - (num_templates,) array of the depth of each template. 
- unwhite_templates : np.ndarray - Unwhitened templates (num_clusters, num_samples, num_channels). - trough_peak_durations : np.ndarray - (num_templates, ) array of durations from trough to peak for each template waveform - waveforms : np.ndarray - (num_templates, num_samples) Waveform of each template, taken as the signal on the maximum loading channel. - """ - # Unwhiten the template waveforms - unwhite_templates = np.zeros_like(templates) - for idx, template in enumerate(templates): - unwhite_templates[idx, :, :] = templates[idx, :, :] @ inverse_whitening_matrix - - # First, calculate the depth of each template from the amplitude - # on each channel by the center of mass method. - - # Take the max amplitude for each channel, then use the channel - # with most signal as template amplitude. Zero any small channel amplitudes. - template_amplitudes_per_channel = np.max(unwhite_templates, axis=1) - np.min(unwhite_templates, axis=1) - - template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1) - - threshold_values = 0.3 * template_amplitudes_unscaled - template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0 - - # Calculate the template depth as the center of mass based on channel amplitudes - template_depths = np.sum(template_amplitudes_per_channel * ycoords[np.newaxis, :], axis=1) / np.sum( - template_amplitudes_per_channel, axis=1 - ) - - # Next, find the depth of each spike based on its template. 
Recompute the template - # amplitudes as the average of the spike amplitudes ('since - # tempScalingAmps are equal mean for all templates') - spike_amplitudes = template_amplitudes_unscaled[spike_templates] * template_scaling_amplitudes - - # Take the average of all spike amplitudes to get actual template amplitudes - # (since tempScalingAmps are equal mean for all templates) - num_indices = templates.shape[0] - sum_per_index = np.zeros(num_indices, dtype=np.float64) - np.add.at(sum_per_index, spike_templates, spike_amplitudes) - counts = np.bincount(spike_templates, minlength=num_indices) - template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0) - - # Each spike's depth is the depth of its template - spike_depths = template_depths[spike_templates] - - # Get channel with the largest amplitude (take that as the waveform) - max_site = np.argmax(np.max(np.abs(templates), axis=1), axis=1) - - # Use template channel with max signal as waveform - waveforms = np.empty(templates.shape[:2]) - for idx, template in enumerate(templates): - waveforms[idx, :] = templates[idx, :, max_site[idx]] - - # Get trough-to-peak time for each template. Find the trough as the - # minimum signal for the template waveform. The duration (in - # samples) is the num samples from trough to the largest value - # following the trough. - waveform_trough = np.argmin(waveforms, axis=1) - - trough_peak_durations = np.zeros(waveforms.shape[0]) - for idx, tmp_max in enumerate(waveforms): - trough_peak_durations[idx] = np.argmax(tmp_max[waveform_trough[idx] :]) - - return ( - spike_amplitudes, - spike_depths, - template_depths, - template_amplitudes, - unwhite_templates, - trough_peak_durations, - waveforms, - ) - -def _load_ks_dir(sorter_output: Path, load_pcs: bool = False) -> dict: - """ - Loads the output of Kilosort into a `params` dict. 
- - This function was ported from Nick Steinmetz's `spikes` repository MATLAB - code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - sorter_output : Path - Path to the kilosort run sorting output. - exclude_noise : bool - If `True`, units labelled as "noise` are removed from all - returned arrays (i.e. both units and associated spikes are dropped). - load_pcs : bool - If `True`, principal component (PC) features are loaded. - - Parameters - ---------- - params : dict - A dictionary of parameters combining both the kilosort `params.py` - file as data loaded from `npy` files. The contents of the `npy` - files can be found in the Phy documentation. - - Notes - ----- - When merging and splitting in `Phy`, all changes are made to the - `spike_clusters.npy` (cluster assignment per spike) and `cluster_groups` - csv/tsv which contains the quality assignment (e.g. "noise") for each cluster. - As this function strips the spikes and units based on only these two - data structures, they will work following manual reassignment in Phy. 
- """ - params = read_python(sorter_output / "params.py") - - spike_times = np.load(sorter_output / "spike_times.npy") / params["sample_rate"] - spike_templates = np.load(sorter_output / "spike_templates.npy") - - temp_scaling_amplitudes = np.load(sorter_output / "amplitudes.npy") - - if load_pcs: - pc_features = np.load(sorter_output / "pc_features.npy") - pc_features_indices = np.load(sorter_output / "pc_feature_ind.npy") - else: - pc_features = pc_features_indices = None - - new_params = { - "spike_times": spike_times.squeeze(), - "spike_templates": spike_templates.squeeze(), - "pc_features": pc_features, - "pc_features_indices": pc_features_indices, - "temp_scaling_amplitudes": temp_scaling_amplitudes.squeeze(), - "channel_positions": np.load(sorter_output / "channel_positions.npy"), - "templates": np.load(sorter_output / "templates.npy"), - "whitening_matrix_inv": np.load(sorter_output / "whitening_mat_inv.npy"), - } - params.update(new_params) - - return params diff --git a/driftmap_viewer/ks_extractors/kilosort_4.py b/driftmap_viewer/ks_extractors/kilosort_4.py deleted file mode 100644 index b7604b5..0000000 --- a/driftmap_viewer/ks_extractors/kilosort_4.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pathlib import Path - -import numpy as np -from . import helpers - - -def get_spikes_info_ks4( - sorter_output: Path, -): # -> tuple[np.ndarray, ...]: - - # TODO: kept_spikes is not always the same size as other spike data - # Currently not used for loading - # kept_spikes = np.load(sorter_output / "kept_spikes.npy") - - spike_times = np.load(sorter_output / "spike_times.npy") - spike_amplitudes = np.load(sorter_output / "amplitudes.npy") - spike_depths = np.load(sorter_output / "spike_positions.npy")[:, 1] - spike_templates = np.load(sorter_output / "spike_templates.npy") # rename spike_tempaltes_idx? - - templates = np.load(sorter_output / "templates.npy") # rename unwihten? 
- channel_positions = np.load(sorter_output / "channel_positions.npy") - - return spike_times, spike_amplitudes, spike_depths, spike_templates, templates, channel_positions diff --git a/driftmap_viewer/mpl_plotting.py b/driftmap_viewer/mpl_plotting.py new file mode 100644 index 0000000..c052170 --- /dev/null +++ b/driftmap_viewer/mpl_plotting.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt +from matplotlib.figure import Figure + +if TYPE_CHECKING: + from driftmap_viewer.data_model import DataModel + + +def plot_matplotlib( + processed_data: DataModel, + amplitude_scaling: str | tuple[float, float], + n_color_bins: int, + point_size: float, + add_histogram_plot: bool, + weight_histogram_by_amplitude: bool, +) -> Figure: + """Render a static matplotlib drift-map figure from pre-processed data. + + Parameters + ---------- + processed_data : + amplitude_scaling : + + n_color_bins : + + point_size : + + add_histogram_plot : + + weight_histogram_by_amplitude : + + Returns + ------- + matplotlib.figure.Figure + The drift-map figure. 
+ """ + # Setup axis and plot the raster drift map + fig = plt.figure(figsize=(10, 10 * (6 / 8))) + + # Set up the axes + if add_histogram_plot: + gs = fig.add_gridspec(1, 2, width_ratios=[1, 5]) + hist_axis = fig.add_subplot(gs[0]) + raster_axis = fig.add_subplot(gs[1], sharey=hist_axis) + else: + raster_axis = fig.add_subplot() + + # Plot the raster plot + spike_times, spike_depths, _ = processed_data.get_scatter_data() + + rgba_colors = processed_data.compute_amplitude_colors( + amplitude_scaling, n_color_bins, unit_normalise=True + ) + + raster_axis.scatter( + spike_times, + spike_depths, + c=rgba_colors, + s=point_size, + antialiased=True, + ) + + if not add_histogram_plot: + raster_axis.set_xlabel("time") + raster_axis.set_ylabel("y position") + return fig + + # Plot the histogram on the left-hand subplot + hist_axis.set_xlabel("count") + raster_axis.set_xlabel("time") + hist_axis.set_ylabel("y position") + + bin_centers, counts = processed_data.compute_activity_histogram( + weight_histogram_by_amplitude + ) + hist_axis.plot(counts, bin_centers, color="black", linewidth=1) + + return fig diff --git a/driftmap_viewer/mpl_plotting/__init__.py b/driftmap_viewer/mpl_plotting/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/driftmap_viewer/mpl_plotting/driftmapviewer_new.py b/driftmap_viewer/mpl_plotting/driftmapviewer_new.py deleted file mode 100644 index 025a585..0000000 --- a/driftmap_viewer/mpl_plotting/driftmapviewer_new.py +++ /dev/null @@ -1,296 +0,0 @@ - -from pathlib import Path -import matplotlib.axis -import scipy.signal -from spikeinterface.core import read_python -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt - -import matplotlib.pyplot as plt -from scipy import stats -from ks_extractors import kilosort_4, kilosort1_3 - - -def get_spikes_info(sorter_output, ks_version, exclude_noise): - """""" - if ks_version == "kilosort4": - spike_times, spike_amplitudes, spike_depths = 
kilosort_4.get_spikes_info_ks4( - sorter_output, exclude_noise - ) - else: - spike_times, spike_amplitudes, spike_depths = kilosort1_3.get_spike_info( - sorter_output, exclude_noise - ) - - return spike_times, spike_amplitudes, spike_depths - - - -# TODO: GAIN DOES NOTHING -# separate function to add histogram - -def get_drift_map_plot( - sorter_output: str | Path, - only_include_large_amplitude_spikes: bool = True, - decimate: None | int = None, - add_histogram_plot: bool = False, - weight_histogram_by_amplitude: bool = False, - exclude_noise: bool = False, - large_amplitude_only_segment_size: float = 800.0, -): - """ - """ - # ks_version = get_ks_version() - sorter_output = Path(sorter_output) - - ks_version = "kilosort4" - - spike_times, spike_amplitudes, spike_depths = get_spikes_info( - sorter_output, ks_version, exclude_noise - ) - - # TODO: this is super weird, can be imrpoved? - if log_transform_amplitudes: - spike_amplitudes = np.log10(spike_amplitudes) - - # Calculate the amplitude range for plotting first, so the scale is always the - # same across all options (e.g. decimation) which helps with interpretability. 
- amplitude_range_all_spikes = ( - spike_amplitudes.min(), - spike_amplitudes.max(), - ) - - if decimate: - spike_times = spike_times[:: decimate] - spike_amplitudes = spike_amplitudes[:: decimate] - spike_depths = spike_depths[:: decimate] - - if only_include_large_amplitude_spikes: - spike_times, spike_amplitudes, spike_depths = _filter_large_amplitude_spikes( - spike_times, spike_amplitudes, spike_depths, - large_amplitude_only_segment_size - ) - - # Setup axis and plot the raster drift map - fig = plt.figure(figsize=(10, 10 * (6 / 8))) - - if add_histogram_plot: - gs = fig.add_gridspec(1, 2, width_ratios=[1, 5]) - hist_axis = fig.add_subplot(gs[0]) - raster_axis = fig.add_subplot(gs[1], sharey=hist_axis) - else: - raster_axis = fig.add_subplot() - - _plot_kilosort_drift_map_raster( - spike_times, - spike_amplitudes, - spike_depths, - amplitude_range_all_spikes, - axis=raster_axis, - ) - - if not add_histogram_plot: - raster_axis.set_xlabel("time") - raster_axis.set_ylabel("y position") - return fig - - # If the histogram plot is requested, plot it alongside - # it's peak colouring, bounds display and drift point display. - hist_axis.set_xlabel("count") - raster_axis.set_xlabel("time") - hist_axis.set_ylabel("y position") - - bin_centers, counts = _compute_activity_histogram( - spike_amplitudes, spike_depths, weight_histogram_by_amplitude - ) - hist_axis.plot(counts, bin_centers, color="black", linewidth=1) - - return fig - - -def _plot_kilosort_drift_map_raster( - spike_times: np.ndarray, - spike_amplitudes: np.ndarray, - spike_depths: np.ndarray, - amplitude_range: np.ndarray | tuple, - axis: matplotlib.axes.Axes, -) -> None: - """ - Plot a drift raster plot in the kilosort style. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - spike_times : np.ndarray - (num_spikes,) array of spike times. 
- spike_amplitudes : np.ndarray - (num_spikes,) array of corresponding spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of corresponding spike depths. - amplitude_range : np.ndarray | tuple - (2,) array of min, max amplitude values for color binning. - axis : matplotlib.axes.Axes - Matplotlib axes object on which to plot the drift map. - """ - n_color_bins = 20 - marker_size = 0.5 - - color_bins = np.linspace(amplitude_range[0], amplitude_range[1], n_color_bins) - - colors = plt.get_cmap("gray")(np.linspace(0, 1, n_color_bins))[::-1] - - for bin_idx in range(n_color_bins - 1): - - spikes_in_amplitude_bin = np.logical_and( - spike_amplitudes >= color_bins[bin_idx], spike_amplitudes <= color_bins[bin_idx + 1] - ) - axis.scatter( - spike_times[spikes_in_amplitude_bin], - spike_depths[spikes_in_amplitude_bin], - color=colors[bin_idx], - s=marker_size, - antialiased=True, - ) - -def _compute_activity_histogram( - spike_amplitudes: np.ndarray, spike_depths: np.ndarray, weight_histogram_by_amplitude: bool -) -> tuple[np.ndarray, ...]: - """ - Compute the activity histogram for the kilosort drift map's left-side plot. - - Parameters - ---------- - spike_amplitudes : np.ndarray - (num_spikes,) array of spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of spike depths. - weight_histogram_by_amplitude : bool - If `True`, the spike amplitudes are taken into consideration when generating the - histogram. The amplitudes are scaled to the range [0, 1] then summed for each bin, - to generate the histogram values. If `False`, counts (i.e. num spikes per bin) - are used. - - Returns - ------- - bin_centers : np.ndarray - The spatial bin centers (probe depth) for the histogram. - values : np.ndarray - The histogram values. If `weight_histogram_by_amplitude` is `False`, these - values represent are counts, otherwise they are counts weighted by amplitude. 
- """ - spike_amplitudes = spike_amplitudes.astype(np.float64) - - bin_um = 2 - bins = np.arange(spike_depths.min() - bin_um, spike_depths.max() + bin_um, bin_um) - values, bins = np.histogram(spike_depths, bins=bins) - bin_centers = (bins[:-1] + bins[1:]) / 2 - - if weight_histogram_by_amplitude: - bin_indices = np.digitize(spike_depths, bins, right=True) - 1 - values = np.zeros(bin_indices.max() + 1, dtype=np.float64) - scaled_spike_amplitudes = (spike_amplitudes - spike_amplitudes.min()) / np.ptp(spike_amplitudes) - np.add.at(values, bin_indices, scaled_spike_amplitudes) - - return bin_centers, values - - -def _color_histogram_peaks_and_detect_drift_events( - spike_times: np.ndarray, - spike_depths: np.ndarray, - counts: np.ndarray, - bin_centers: np.ndarray, - hist_axis: matplotlib.axes.Axes, -) -> np.ndarray: - """ - Given an activity histogram, color the peaks red (isolated peak) or - blue (peak overlaps with other peaks) and compute spatial drift - events for isolated peaks across time bins. - - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes - - Parameters - ---------- - spike_times : np.ndarray - (num_spikes,) array of spike times. - spike_depths : np.ndarray - (num_spikes,) array of corresponding spike depths. - counts : np.ndarray - (num_bins,) array of histogram bin counts. - bin_centers : np.ndarray - (num_bins,) array of histogram bin centers. - hist_axis : matplotlib.axes.Axes - Axes on which the histogram is plot, to add peaks. - - Returns - ------- - drift_events : np.ndarray - A (num_drift_events, 3) array of drift events. The columns are - (time_position, spatial_position, drift_value). The drift - value is computed per time, spatial bin as the difference between - the median position of spikes in the bin, and the bin center. 
- """ - all_peak_indexes = scipy.signal.find_peaks( - counts, - )[0] - - # Filter low-frequency peaks, so they are not included in the - # step to determine whether peaks are overlapping (new step - # introduced in the port to python) - bin_above_freq_threshold = counts[all_peak_indexes] > 0.3 * spike_times[-1] - filtered_peak_indexes = all_peak_indexes[bin_above_freq_threshold] - - drift_events = [] - for idx, peak_index in enumerate(filtered_peak_indexes): - - peak_count = counts[peak_index] - - # Find the start and end of peak min/max bounds (5% of amplitude) - start_position = np.where(counts[:peak_index] < peak_count * 0.05)[0].max() - end_position = np.where(counts[peak_index:] < peak_count * 0.05)[0].min() + peak_index - - if ( # bounds include another, different histogram peak - idx > 0 - and start_position < filtered_peak_indexes[idx - 1] - or idx < filtered_peak_indexes.size - 1 - and end_position > filtered_peak_indexes[idx + 1] - ): - hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="blue") - continue - - else: - for position in [start_position, end_position]: - hist_axis.axhline(bin_centers[position], 0, counts.max(), color="grey", linestyle="--") - hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="red") - - # For isolated histogram peaks, detect the drift events, defined as - # difference between spatial bin center and median spike depth in the bin - # over 6 um (in time / spatial bins with at least 10 spikes). 
- depth_in_window = np.logical_and( - spike_depths > bin_centers[start_position], - spike_depths < bin_centers[end_position], - ) - current_spike_depths = spike_depths[depth_in_window] - current_spike_times = spike_times[depth_in_window] - - window_s = 10 - - all_time_bins = np.arange(0, np.ceil(spike_times[-1]).astype(int), window_s) - for time_bin in all_time_bins: - - spike_in_time_bin = np.logical_and( - current_spike_times >= time_bin, current_spike_times <= time_bin + window_s - ) - drift_size = bin_centers[peak_index] - np.median(current_spike_depths[spike_in_time_bin]) - - # 6 um is the hardcoded threshold for drift, and we want at least 10 spikes for the median calculation - bin_has_drift = np.abs(drift_size) > 6 and np.sum(spike_in_time_bin, dtype=np.int16) > 10 - if bin_has_drift: - drift_events.append((time_bin + window_s / 2, bin_centers[peak_index], drift_size)) - - drift_events = np.array(drift_events) - - return drift_events \ No newline at end of file diff --git a/examples/example_amplitudes.py b/examples/example_amplitudes.py new file mode 100644 index 0000000..e2011bd --- /dev/null +++ b/examples/example_amplitudes.py @@ -0,0 +1,43 @@ +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import spikeinterface as si + +from driftmap_viewer import DriftMapView, get_amplitudes + +# Getting the amplitudes across a set of sorting outputs can be useful to +# compute absolute amplitudes used for filtering spikes based +# on amplitude, or scaling color map values the same across plots. + +# Load the data. 
In this example we load as a sorting analyzer +# or from the raw kilosort output to demonstrate both methods +data_path = Path(__file__).parent / "example_data" +analyzer = si.load_sorting_analyzer(data_path / "analyzer.zarr") +sorting_output_path = data_path / "sorting" / "sorter_output" + +all_spike_amplitudes = get_amplitudes( + [analyzer, sorting_output_path], concatenate=False +) + +fig, axes = plt.subplots(1, 2) +for idx, amplitudes in enumerate(all_spike_amplitudes): + axes[idx].hist(amplitudes, bins=25) + axes[idx].set_title(f"Session: {idx}") + +plt.show() + +concat_spike_amplitudes = np.concatenate(all_spike_amplitudes) +min_cutoff, max_cutoff = concat_spike_amplitudes.min(), concat_spike_amplitudes.max() + +for path_or_analzyer in [analyzer, sorting_output_path]: + plotter = DriftMapView(analyzer) + + plot = plotter.drift_map_plot_matplotlib( + amplitude_scaling=(min_cutoff, max_cutoff), + n_color_bins=25, + filter_amplitude_mode=None, + exclude_noise="KSLabel", + ) + + plt.show() diff --git a/examples/example_data/analyzer.zarr/.zattrs b/examples/example_data/analyzer.zarr/.zattrs new file mode 100644 index 0000000..6d567cf --- /dev/null +++ b/examples/example_data/analyzer.zarr/.zattrs @@ -0,0 +1,10 @@ +{ + "settings": { + "return_in_uV": true + }, + "spikeinterface_info": { + "dev_mode": true, + "object": "SortingAnalyzer", + "version": "0.103.3" + } +} diff --git a/examples/example_data/analyzer.zarr/.zgroup b/examples/example_data/analyzer.zarr/.zgroup new file mode 100644 index 0000000..3f3fad2 --- /dev/null +++ b/examples/example_data/analyzer.zarr/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} diff --git a/examples/example_data/analyzer.zarr/.zmetadata b/examples/example_data/analyzer.zarr/.zmetadata new file mode 100644 index 0000000..c8c24f0 --- /dev/null +++ b/examples/example_data/analyzer.zarr/.zmetadata @@ -0,0 +1,32300 @@ +{ + "metadata": { + ".zattrs": { + "settings": { + "return_in_uV": true + }, + "spikeinterface_info": { + 
"dev_mode": true, + "object": "SortingAnalyzer", + "version": "0.103.3" + } + }, + ".zgroup": { + "zarr_format": 2 + }, + "extensions/.zgroup": { + "zarr_format": 2 + }, + "extensions/random_spikes/.zattrs": { + "info": { + "class": "spikeinterface.core.analyzer_extension_core.ComputeRandomSpikes", + "module": "spikeinterface", + "version": "0.103.3" + }, + "params": { + "margin_size": null, + "max_spikes_per_unit": 1000000, + "method": "uniform", + "seed": null + }, + "run_info": { + "run_completed": true, + "runtime_s": 0.001311899977736175 + } + }, + "extensions/random_spikes/.zgroup": { + "zarr_format": 2 + }, + "extensions/random_spikes/random_spikes_indices/.zarray": { + "chunks": [ + 2683 + ], + "compressor": { + "blocksize": 0, + "clevel": 5, + "cname": "zstd", + "id": "blosc", + "shuffle": 2 + }, + "dtype": "