diff --git a/examples/get_started/example_bombcell_unit_labelling.ipynb b/examples/get_started/example_bombcell_unit_labelling.ipynb new file mode 100644 index 0000000000..8b18aaec90 --- /dev/null +++ b/examples/get_started/example_bombcell_unit_labelling.ipynb @@ -0,0 +1,262 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Bombcell unit labelling\n", + "\n", + "With this notebook you can:\n", + "- load a SortingAnalyzer\n", + "- compute required extensions\n", + "- label units based on quality thresholds\n", + "- generate and save summary plots\n", + "- save metrics and results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import spikeinterface as si\n", + "from spikeinterface.curation import (\n", + " bombcell_get_default_thresholds,\n", + " bombcell_label_units,\n", + " save_thresholds,\n", + " load_thresholds,\n", + ")\n", + "from spikeinterface.widgets import plot_unit_labelling_all" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### load a SortingAnalyzer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Change this to your analyzer path - you need to have already generated a sorting analyzer. See quickstart.py for how to do this.\n", + "analyzer_path = \"/Users/jf5479/Downloads/M25_D18/kilosort4_sa\"\n", + "output_folder = Path(analyzer_path) / \"bombcell\"\n", + "\n", + "analyzer = si.load_sorting_analyzer(analyzer_path)\n", + "analyzer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### compute required extensions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Templates (required for template_metrics)\n", + "if not analyzer.has_extension(\"templates\"):\n", + " analyzer.compute(\"templates\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Template metrics\n", + "if not analyzer.has_extension(\"template_metrics\"):\n", + " analyzer.compute(\"template_metrics\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Quality metrics (and dependencies)\n", + "if not analyzer.has_extension(\"spike_amplitudes\"):\n", + " analyzer.compute(\"spike_amplitudes\")\n", + "\n", + "if not analyzer.has_extension(\"noise_levels\"):\n", + " analyzer.compute(\"noise_levels\")\n", + "\n", + "if not analyzer.has_extension(\"quality_metrics\"):\n", + " analyzer.compute(\"quality_metrics\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### get metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qm = analyzer.get_extension(\"quality_metrics\").get_data()\n", + "tm = analyzer.get_extension(\"template_metrics\").get_data()\n", + "\n", + "print(f\"Quality metrics: {list(qm.columns)}\")\n", + "print(f\"Template metrics: {list(tm.columns)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### set labelling thresholds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use default thresholds\n", + "thresholds = bombcell_get_default_thresholds()\n", + "\n", + "# Or load from file:\n", + "# thresholds = 
load_thresholds(\"my_thresholds.json\")\n", + "\n", + "thresholds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Optionally modify thresholds\n", + "# thresholds[\"amplitude_median\"][\"min\"] = 50 # stricter\n", + "# thresholds[\"rp_contamination\"][\"max\"] = 0.05 # stricter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Optionally set and load thresholds from a JSON file \n", + "# Load thresholds from saved JSON\n", + "thresholds = load_thresholds(output_folder / \"thresholds.json\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The JSON file format looks like:\n", + "```json\n", + "{\n", + " \"amplitude_median\": {\"min\": 40, \"max\": null},\n", + " \"num_positive_peaks\": {\"min\": null, \"max\": 2},\n", + " \"peak_to_trough_duration\": {\"min\": 0.0001, \"max\": 0.00115}\n", + "}\n", + "```\n", + "`null` in JSON becomes `np.nan` (threshold disabled)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### label units" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "unit_type, unit_type_string = bombcell_label_units(\n", + " quality_metrics=qm,\n", + " template_metrics=tm,\n", + " thresholds=thresholds,\n", + " label_non_somatic=True,\n", + " split_non_somatic_good_mua=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### generate summary plots" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plots = plot_unit_labelling_all(\n", + " analyzer,\n", + " unit_type,\n", + " unit_type_string,\n", + " quality_metrics=qm,\n", + " template_metrics=tm,\n", + " thresholds=thresholds,\n", + " save_folder=output_folder,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### save labelling thresholds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "save_thresholds(thresholds, output_folder / \"thresholds.json\")\n", + "\n", + "print(f\"Results saved to: {output_folder.absolute()}\")\n", + "print(\"\\nFiles:\")\n", + "for f in sorted(output_folder.glob(\"*\")):\n", + " print(f\" - {f.name}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index af7fb90f94..944dd59338 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -21,5 +21,18 @@ from .sortingview_curation import apply_sortingview_curation # automated curation +from .unit_labelling import ( + WAVEFORM_METRICS, + SPIKE_QUALITY_METRICS, + NON_SOMATIC_METRICS, + bombcell_get_default_thresholds, + bombcell_label_units, + apply_thresholds, + get_labelling_summary, + save_thresholds, + load_thresholds, + save_labelling_results, +) + from .model_based_curation import auto_label_units, load_model from .train_manual_curation import train_model, get_default_classifier_search_spaces diff --git a/src/spikeinterface/curation/default_thresholds.json b/src/spikeinterface/curation/default_thresholds.json new file mode 100644 index 0000000000..8e39d89179 --- 
/dev/null +++ b/src/spikeinterface/curation/default_thresholds.json @@ -0,0 +1,74 @@ +{ + "num_positive_peaks": { + "min": null, + "max": 2 + }, + "num_negative_peaks": { + "min": null, + "max": 1 + }, + "peak_to_trough_duration": { + "min": 0.0001, + "max": 0.00115 + }, + "waveform_baseline_flatness": { + "min": null, + "max": 0.5 + }, + "peak_after_to_trough_ratio": { + "min": null, + "max": 0.8 + }, + "exp_decay": { + "min": 0.01, + "max": 0.1 + }, + "amplitude_median": { + "min": 40, + "max": null + }, + "snr_bombcell": { + "min": 5, + "max": null + }, + "amplitude_cutoff": { + "min": null, + "max": 0.2 + }, + "num_spikes": { + "min": 300, + "max": null + }, + "rp_contamination": { + "min": null, + "max": 0.1 + }, + "presence_ratio": { + "min": 0.7, + "max": null + }, + "drift_ptp": { + "min": null, + "max": 100 + }, + "peak_before_to_trough_ratio": { + "min": null, + "max": 3 + }, + "peak_before_width": { + "min": 150, + "max": null + }, + "trough_width": { + "min": 200, + "max": null + }, + "peak_before_to_peak_after_ratio": { + "min": null, + "max": 3 + }, + "main_peak_to_trough_ratio": { + "min": null, + "max": 0.8 + } +} diff --git a/src/spikeinterface/curation/unit_labelling.py b/src/spikeinterface/curation/unit_labelling.py new file mode 100644 index 0000000000..814632da6e --- /dev/null +++ b/src/spikeinterface/curation/unit_labelling.py @@ -0,0 +1,430 @@ +""" +Unit labelling based on quality metrics (Bombcell). + +Unit Types: + 0 (NOISE): Failed waveform quality checks + 1 (GOOD): Passed all thresholds + 2 (MUA): Failed spike quality checks + 3 (NON_SOMA): Non-somatic units (axonal) +""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +from typing import Optional + + +WAVEFORM_METRICS = [ + "num_positive_peaks", + "num_negative_peaks", + "peak_to_trough_duration", + "waveform_baseline_flatness", + "peak_after_to_trough_ratio", + "exp_decay", +] + +SPIKE_QUALITY_METRICS = [ + "amplitude_median", + "snr_bombcell", + "amplitude_cutoff", + "num_spikes", + "rp_contamination", + "presence_ratio", + "drift_ptp", +] + +NON_SOMATIC_METRICS = [ + "peak_before_to_trough_ratio", + "peak_before_width", + "trough_width", + "peak_before_to_peak_after_ratio", + "main_peak_to_trough_ratio", +] + + +def bombcell_get_default_thresholds() -> dict: + """ + Bombcell - Returns default thresholds for unit labelling. + + Each metric has 'min' and 'max' values. Use np.nan to disable a threshold (e.g. 
to ignore a metric completely, + or to use only a min or only a max threshold). + """ + # default threshold values adapted from bombcell + return { + # Waveform quality (failures -> NOISE) + "num_positive_peaks": {"min": np.nan, "max": 2}, + "num_negative_peaks": {"min": np.nan, "max": 1}, + "peak_to_trough_duration": {"min": 0.0001, "max": 0.00115}, # seconds + "waveform_baseline_flatness": {"min": np.nan, "max": 0.5}, + "peak_after_to_trough_ratio": {"min": np.nan, "max": 0.8}, + "exp_decay": {"min": 0.01, "max": 0.1}, + # Spike quality (failures -> MUA) + "amplitude_median": {"min": 40, "max": np.nan}, # uV + "snr_bombcell": {"min": 5, "max": np.nan}, + "amplitude_cutoff": {"min": np.nan, "max": 0.2}, + "num_spikes": {"min": 300, "max": np.nan}, + "rp_contamination": {"min": np.nan, "max": 0.1}, + "presence_ratio": {"min": 0.7, "max": np.nan}, + "drift_ptp": {"min": np.nan, "max": 100}, # um + # Non-somatic detection + "peak_before_to_trough_ratio": {"min": np.nan, "max": 3}, + "peak_before_width": {"min": 150, "max": np.nan}, # us + "trough_width": {"min": 200, "max": np.nan}, # us + "peak_before_to_peak_after_ratio": {"min": np.nan, "max": 3}, + "main_peak_to_trough_ratio": {"min": np.nan, "max": 0.8}, + } + + + def _combine_metrics(quality_metrics, template_metrics): + """Combine quality_metrics and template_metrics into a single DataFrame.""" + if quality_metrics is None and template_metrics is None: + return None + if quality_metrics is None: + return template_metrics + if template_metrics is None: + return quality_metrics + return quality_metrics.join(template_metrics, how="outer") + + + def bombcell_label_units( + quality_metrics=None, + template_metrics=None, + thresholds: Optional[dict] = None, + label_non_somatic: bool = True, + split_non_somatic_good_mua: bool = False, + ) -> tuple[np.ndarray, np.ndarray]: + """ + Bombcell - label units based on quality metrics and thresholds. + + Parameters + ---------- + quality_metrics : pd.DataFrame, optional + DataFrame with quality metrics (index = unit_ids). + template_metrics : pd.DataFrame, optional + DataFrame with template metrics (index = unit_ids). + thresholds : dict or None + Threshold dict: {"metric": {"min": val, "max": val}}. Use np.nan to disable. + label_non_somatic : bool + If True, detect non-somatic (axonal) units. + split_non_somatic_good_mua : bool + If True, split non-somatic into NON_SOMA_GOOD (3) and NON_SOMA_MUA (4). + + Returns + ------- + unit_type : np.ndarray + Numeric: 0=NOISE, 1=GOOD, 2=MUA, 3=NON_SOMA (4=NON_SOMA_MUA when split_non_somatic_good_mua=True) + unit_type_string : np.ndarray + String labels. 
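+ + Examples + -------- + A minimal usage sketch, assuming ``qm`` and ``tm`` are the quality-metric and + template-metric DataFrames computed on a SortingAnalyzer: + + >>> thresholds = bombcell_get_default_thresholds() + >>> unit_type, unit_type_string = bombcell_label_units( + ... quality_metrics=qm, template_metrics=tm, thresholds=thresholds + ... )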
+ """ + combined_metrics = _combine_metrics(quality_metrics, template_metrics) + if combined_metrics is None: + raise ValueError("At least one of quality_metrics or template_metrics must be provided") + + if thresholds is None: + thresholds = bombcell_get_default_thresholds() + + n_units = len(combined_metrics) + unit_type = np.full(n_units, np.nan) + absolute_value_metrics = ["amplitude_median"] + + # NOISE: waveform failures + noise_mask = np.zeros(n_units, dtype=bool) + for metric_name in WAVEFORM_METRICS: + if metric_name not in combined_metrics.columns or metric_name not in thresholds: + continue + values = combined_metrics[metric_name].values + if metric_name in absolute_value_metrics: + values = np.abs(values) + thresh = thresholds[metric_name] + noise_mask |= np.isnan(values) + if not np.isnan(thresh["min"]): + noise_mask |= values < thresh["min"] + if not np.isnan(thresh["max"]): + noise_mask |= values > thresh["max"] + unit_type[noise_mask] = 0 + + # MUA: spike quality failures + mua_mask = np.zeros(n_units, dtype=bool) + for metric_name in SPIKE_QUALITY_METRICS: + if metric_name not in combined_metrics.columns or metric_name not in thresholds: + continue + values = combined_metrics[metric_name].values + if metric_name in absolute_value_metrics: + values = np.abs(values) + thresh = thresholds[metric_name] + valid_mask = np.isnan(unit_type) + if not np.isnan(thresh["min"]): + mua_mask |= valid_mask & ~np.isnan(values) & (values < thresh["min"]) + if not np.isnan(thresh["max"]): + mua_mask |= valid_mask & ~np.isnan(values) & (values > thresh["max"]) + unit_type[mua_mask & np.isnan(unit_type)] = 2 + + # GOOD: passed all checks + unit_type[np.isnan(unit_type)] = 1 + + # NON-SOMATIC + if label_non_somatic: + + def get_metric(name): + if name in combined_metrics.columns: + return combined_metrics[name].values + return np.full(n_units, np.nan) + + peak_before_width = get_metric("peak_before_width") + trough_width = get_metric("trough_width") + width_thresh_peak = thresholds.get("peak_before_width", {}).get("min", np.nan) + width_thresh_trough = thresholds.get("trough_width", {}).get("min", np.nan) + + narrow_peak = ( + ~np.isnan(peak_before_width) & (peak_before_width < width_thresh_peak) + if not np.isnan(width_thresh_peak) + else np.zeros(n_units, dtype=bool) + ) + narrow_trough = ( + ~np.isnan(trough_width) & (trough_width < width_thresh_trough) + if not np.isnan(width_thresh_trough) + else np.zeros(n_units, dtype=bool) + ) + width_conditions = narrow_peak & narrow_trough + + peak_before_to_trough = get_metric("peak_before_to_trough_ratio") + peak_before_to_peak_after = get_metric("peak_before_to_peak_after_ratio") + main_peak_to_trough = get_metric("main_peak_to_trough_ratio") + + ratio_thresh_pbt = thresholds.get("peak_before_to_trough_ratio", {}).get("max", np.nan) + ratio_thresh_pbpa = thresholds.get("peak_before_to_peak_after_ratio", {}).get("max", np.nan) + ratio_thresh_mpt = thresholds.get("main_peak_to_trough_ratio", {}).get("max", np.nan) + + large_initial_peak = ( + ~np.isnan(peak_before_to_trough) & (peak_before_to_trough > ratio_thresh_pbt) + if not np.isnan(ratio_thresh_pbt) + else np.zeros(n_units, dtype=bool) + ) + large_peak_ratio = ( + ~np.isnan(peak_before_to_peak_after) & (peak_before_to_peak_after > ratio_thresh_pbpa) + if not np.isnan(ratio_thresh_pbpa) + else np.zeros(n_units, dtype=bool) + ) + large_main_peak = ( + ~np.isnan(main_peak_to_trough) & (main_peak_to_trough > ratio_thresh_mpt) + if not np.isnan(ratio_thresh_mpt) + else np.zeros(n_units, dtype=bool) + 
) + + # (ratio AND width) OR standalone main_peak_to_trough + ratio_conditions = large_initial_peak | large_peak_ratio + is_non_somatic = (ratio_conditions & width_conditions) | large_main_peak + + if split_non_somatic_good_mua: + unit_type[(unit_type == 1) & is_non_somatic] = 3 + unit_type[(unit_type == 2) & is_non_somatic] = 4 + else: + unit_type[(unit_type != 0) & is_non_somatic] = 3 + + # String labels + if split_non_somatic_good_mua: + labels = {0: "NOISE", 1: "GOOD", 2: "MUA", 3: "NON_SOMA_GOOD", 4: "NON_SOMA_MUA"} + else: + labels = {0: "NOISE", 1: "GOOD", 2: "MUA", 3: "NON_SOMA"} + + unit_type_string = np.array([labels.get(int(t), "UNKNOWN") for t in unit_type], dtype=object) + return unit_type.astype(int), unit_type_string + + +def apply_thresholds( + quality_metrics: pd.DataFrame, + thresholds: Optional[dict] = None, +) -> pd.DataFrame: + """ + Apply thresholds and return pass/fail status for each metric. + Useful for debugging classification results. + """ + if thresholds is None: + thresholds = bombcell_get_default_thresholds() + + results = {} + for metric_name, thresh in thresholds.items(): + if metric_name not in quality_metrics.columns: + continue + + values = quality_metrics[metric_name].values + n_units = len(values) + passes = np.ones(n_units, dtype=bool) + reasons = np.array([""] * n_units, dtype=object) + + nan_mask = np.isnan(values) + passes[nan_mask] = False + reasons[nan_mask] = "nan" + + if not np.isnan(thresh["min"]): + below_min = ~nan_mask & (values < thresh["min"]) + passes[below_min] = False + reasons[below_min] = "below_min" + + if not np.isnan(thresh["max"]): + above_max = ~nan_mask & (values > thresh["max"]) + passes[above_max] = False + reasons[above_max & (reasons == "")] = "above_max" + reasons[above_max & (reasons == "below_min")] = "below_min_and_above_max" + + results[f"{metric_name}_pass"] = passes + results[f"{metric_name}_fail_reason"] = reasons + + return pd.DataFrame(results, index=quality_metrics.index) + + +def get_labelling_summary(unit_type: np.ndarray, unit_type_string: np.ndarray) -> dict: + """Get counts and percentages for each unit type.""" + n_total = len(unit_type) + unique_types, counts = np.unique(unit_type, return_counts=True) + + summary = {"total_units": n_total, "counts": {}, "percentages": {}} + for utype, count in zip(unique_types, counts): + label = unit_type_string[unit_type == utype][0] + summary["counts"][label] = int(count) + summary["percentages"][label] = round(100 * count / n_total, 1) + + return summary + + +def save_thresholds(thresholds: dict, filepath) -> None: + """ + Save thresholds to a JSON file. + + Parameters + ---------- + thresholds : dict + Threshold dictionary from bombcell_get_default_thresholds() or modified version. + filepath : str or Path + Path to save the JSON file. + """ + import json + from pathlib import Path + + # Convert np.nan to None for JSON serialization + json_thresholds = {} + for metric_name, thresh in thresholds.items(): + json_thresholds[metric_name] = { + "min": None if (isinstance(thresh["min"], float) and np.isnan(thresh["min"])) else thresh["min"], + "max": None if (isinstance(thresh["max"], float) and np.isnan(thresh["max"])) else thresh["max"], + } + + filepath = Path(filepath) + with open(filepath, "w") as f: + json.dump(json_thresholds, f, indent=4) + + +def load_thresholds(filepath) -> dict: + """ + Load thresholds from a JSON file. + + Parameters + ---------- + filepath : str or Path + Path to the JSON file. 
+ + Returns + ------- + thresholds : dict + Threshold dictionary compatible with bombcell_label_units(). + """ + import json + from pathlib import Path + + filepath = Path(filepath) + with open(filepath, "r") as f: + json_thresholds = json.load(f) + + # Convert None to np.nan + thresholds = {} + for metric_name, thresh in json_thresholds.items(): + thresholds[metric_name] = { + "min": np.nan if thresh["min"] is None else thresh["min"], + "max": np.nan if thresh["max"] is None else thresh["max"], + } + + return thresholds + + + def save_labelling_results( + quality_metrics: pd.DataFrame, + unit_type: np.ndarray, + unit_type_string: np.ndarray, + thresholds: dict, + folder, + save_narrow: bool = True, + save_wide: bool = True, + ) -> None: + """ + Save labelling results to CSV files. + + Parameters + ---------- + quality_metrics : pd.DataFrame + DataFrame with quality metrics (index = unit_ids). + unit_type : np.ndarray + Numeric unit type codes. + unit_type_string : np.ndarray + String labels for each unit. + thresholds : dict + Threshold dictionary used for labelling. + folder : str or Path + Folder to save the CSV files. + save_narrow : bool, default: True + Save narrow/tidy format (one row per unit-metric). + save_wide : bool, default: True + Save wide format (one row per unit, metrics as columns). + """ + from pathlib import Path + + folder = Path(folder) + folder.mkdir(parents=True, exist_ok=True) + + unit_ids = quality_metrics.index.values + + # Wide format: one row per unit + if save_wide: + wide_df = quality_metrics.copy() + wide_df.insert(0, "label", unit_type_string) + wide_df.insert(1, "label_code", unit_type) + wide_df.to_csv(folder / "labelling_results_wide.csv") + + # Narrow format: one row per unit-metric combination + if save_narrow: + rows = [] + for i, unit_id in enumerate(unit_ids): + label = unit_type_string[i] + label_code = unit_type[i] + for metric_name in quality_metrics.columns: + if metric_name not in thresholds: + continue + value = quality_metrics.loc[unit_id, metric_name] + thresh = thresholds[metric_name] + thresh_min = thresh.get("min", np.nan) + thresh_max = thresh.get("max", np.nan) + + # Determine pass/fail + passed = True + if np.isnan(value): + passed = False + elif not np.isnan(thresh_min) and value < thresh_min: + passed = False + elif not np.isnan(thresh_max) and value > thresh_max: + passed = False + + rows.append( + { + "unit_id": unit_id, + "label": label, + "label_code": label_code, + "metric_name": metric_name, + "value": value, + "threshold_min": None if np.isnan(thresh_min) else thresh_min, + "threshold_max": None if np.isnan(thresh_max) else thresh_max, + "passed": passed, + } + ) + + narrow_df = pd.DataFrame(rows) + narrow_df.to_csv(folder / "labelling_results_narrow.csv", index=False) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index c6b07da52e..8c6339b773 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -182,6 +182,114 @@ class SNR(BaseMetric): depend_on = ["noise_levels", "templates"] +def compute_snrs_bombcell( + sorting_analyzer, + unit_ids=None, + peak_sign: str = "neg", + baseline_window_ms: float = 0.5, +): + """ + Compute the signal-to-noise ratio using the BombCell method. 
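+ + In short, ``snr = max(abs(raw_waveforms_on_peak_channel)) / mad(baseline_samples)``, where ``mad`` is the median absolute deviation.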
+ + This differs from the standard SNR by using: + - Signal: Max absolute value of raw waveforms on peak channel + - Noise: MAD (Median Absolute Deviation) of baseline samples from waveforms + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + unit_ids : list or None + The list of unit ids to compute the SNR. If None, all units are used. + peak_sign : "neg" | "pos" | "both", default: "neg" + The sign of the template to compute best channels. + baseline_window_ms : float, default: 0.5 + Duration in ms at the start of the waveform to use as baseline for noise calculation. + + Returns + ------- + snrs : dict + Computed signal to noise ratio for each unit. + + Notes + ----- + This implementation follows the BombCell methodology: + - Signal is the maximum absolute amplitude of raw waveforms on the peak channel + - Noise is computed as MAD of baseline samples (first N samples of each waveform) + + Requires the "waveforms" extension to be computed. + """ + if not sorting_analyzer.has_extension("waveforms"): + raise ValueError( + "The 'waveforms' extension is required for compute_snrs_bombcell. " + "Please compute it first with: analyzer.compute('waveforms')" + ) + + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids + + waveforms_ext = sorting_analyzer.get_extension("waveforms") + nbefore = waveforms_ext.nbefore + sampling_frequency = sorting_analyzer.sampling_frequency + + # Calculate baseline samples from ms + baseline_samples = int(baseline_window_ms / 1000 * sampling_frequency) + baseline_samples = min(baseline_samples, nbefore) # Can't exceed nbefore + + # Get peak channel for each unit from templates + extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) + + snrs = {} + for unit_id in unit_ids: + # Get waveforms for this unit (num_spikes, num_samples, num_channels) + waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) + + if waveforms is None or len(waveforms) == 0: + snrs[unit_id] = np.nan + continue + + # Get peak channel index + peak_chan_id = extremum_channels_ids[unit_id] + if sorting_analyzer.is_sparse(): + chan_ids = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] + if peak_chan_id not in chan_ids: + snrs[unit_id] = np.nan + continue + peak_chan_idx = np.where(chan_ids == peak_chan_id)[0][0] + else: + peak_chan_idx = sorting_analyzer.channel_ids_to_indices([peak_chan_id])[0] + + # Extract waveforms on peak channel + waveforms_peak = waveforms[:, :, peak_chan_idx] # (num_spikes, num_samples) + + # Signal: max absolute value across all spikes + signal = np.max(np.abs(waveforms_peak)) + + # Noise: MAD of baseline samples (first N samples of each waveform) + baseline_samples_all = waveforms_peak[:, :baseline_samples].flatten() + median_baseline = np.median(baseline_samples_all) + noise = np.median(np.abs(baseline_samples_all - median_baseline)) + + # Calculate SNR (avoid division by zero) + if noise > 0: + snrs[unit_id] = signal / noise + else: + snrs[unit_id] = np.nan + + return snrs + + +class SNRBombcell(BaseMetric): + metric_name = "snr_bombcell" + metric_function = compute_snrs_bombcell + metric_params = {"peak_sign": "neg", "baseline_window_ms": 0.5} + metric_columns = {"snr_bombcell": float} + metric_descriptions = { + "snr_bombcell": "Signal to noise ratio using BombCell method (raw waveform max / baseline MAD)." 
} + depend_on = ["waveforms", "templates"] + + def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0): """ Calculate Inter-Spike Interval (ISI) violations. @@ -752,6 +860,7 @@ def compute_amplitude_cutoffs( num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5, + plot_details=False, ): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. @@ -770,6 +879,9 @@ The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. + plot_details : bool, default: False + If True, generate diagnostic plots for each unit showing the amplitude + histogram and Gaussian fit. Returns ------- @@ -807,13 +919,38 @@ amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) + # Get spike times for scatter plots if plot_details is enabled + spike_times_by_units = None + if plot_details: + sorting = sorting_analyzer.sorting + fs = sorting_analyzer.sampling_frequency + # Get spike times by unit (concatenated across segments) + spike_times_by_units = {} + for unit_id in unit_ids: + all_spike_times = [] + time_offset = 0.0 + for seg_idx in range(sorting_analyzer.get_num_segments()): + spike_train = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=seg_idx) + spike_times_s = spike_train / fs + time_offset + all_spike_times.append(spike_times_s) + time_offset += sorting_analyzer.get_num_samples(seg_idx) / fs + spike_times_by_units[unit_id] = np.concatenate(all_spike_times) if all_spike_times else np.array([]) + for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] if invert_amplitudes: amplitudes = -amplitudes + spike_times = spike_times_by_units[unit_id] if spike_times_by_units is not None else None + all_fraction_missing[unit_id] = amplitude_cutoff( - amplitudes, num_histogram_bins, histogram_smoothing_value, amplitudes_bins_min_ratio + amplitudes, + num_histogram_bins, + histogram_smoothing_value, + amplitudes_bins_min_ratio, + spike_times=spike_times, + unit_id=unit_id, + plot_details=plot_details, ) if np.any(np.isnan(list(all_fraction_missing.values()))): @@ -829,6 +966,7 @@ class AmplitudeCutoff(BaseMetric): "num_histogram_bins": 100, "histogram_smoothing_value": 3, "amplitudes_bins_min_ratio": 5, + "plot_details": False, } metric_columns = {"amplitude_cutoff": float} metric_descriptions = { @@ -1295,6 +1433,7 @@ class SDRatio(BaseMetric): FiringRate, PresenceRatio, SNR, + SNRBombcell, ISIViolation, RPViolation, SlidingRPViolation, @@ -1421,7 +1560,17 @@ def isi_violations(spike_trains, total_duration_s, isi_threshold_s=0.0015, min_i return isi_violations_ratio, isi_violations_rate, isi_violations_count -def amplitude_cutoff(amplitudes, num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5): +def amplitude_cutoff( + amplitudes, + num_histogram_bins=500, + histogram_smoothing_value=3, + amplitudes_bins_min_ratio=5, + spike_times=None, + unit_id=None, + plot_details=False, + ax_scatter=None, + ax_hist=None, +): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. @@ -1439,6 +1588,18 @@ def amplitude_cutoff(amplitudes, num_histogram_bins=500, histogram_smoothing_val The minimum ratio between number of amplitudes for a unit and the number of bins. 
If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. + spike_times : ndarray_like or None, default: None + The spike times (in seconds) for this unit. Used for the scatter plot. + unit_id : any, default: None + The unit ID for labeling plots. + plot_details : bool, default: False + If True, generate diagnostic plots showing the amplitude histogram and + Gaussian fit. + ax_scatter : matplotlib axis or None, default: None + Axis for scatter plot (spike times vs amplitudes). If None and plot_details=True, + a new figure is created. + ax_hist : matplotlib axis or None, default: None + Axis for histogram plot. If None and plot_details=True, uses same figure. Returns ------- @@ -1471,6 +1632,102 @@ def amplitude_cutoff(amplitudes, num_histogram_bins=500, histogram_smoothing_val fraction_missing = np.sum(pdf[G:]) * bin_size fraction_missing = np.min([fraction_missing, 0.5]) + # Plot details for debugging (similar to MATLAB BombCell) + if plot_details: + import matplotlib.pyplot as plt + + # Create figure if no axes provided + if ax_scatter is None and ax_hist is None: + fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + ax_scatter = axes[0] + ax_hist = axes[1] + created_figure = True + else: + created_figure = False + + # Colors matching MATLAB BombCell style + main_color = [0, 0.35, 0.71] # Blue + cutoff_color = [0.5430, 0, 0.5430] # Purple + fit_color = "red" + + # Plot 1: Scatter plot of spike times vs amplitudes (if spike_times provided) + if ax_scatter is not None and spike_times is not None: + ax_scatter.scatter(spike_times, amplitudes, s=4, c=[main_color], alpha=0.5) + + # Add outlier threshold line (using IQR method like MATLAB) + q1, q3 = np.percentile(amplitudes, [25, 75]) + iqr = q3 - q1 + iqr_threshold = 4 # Same as MATLAB default + outlier_line = q3 + iqr_threshold * iqr + + # Use the current x-limits to right-align the threshold label + xlims = ax_scatter.get_xlim() + + ax_scatter.axhline(outlier_line, color=cutoff_color, linewidth=1.5) + ax_scatter.text( + xlims[1] * 0.98, + outlier_line * 1.02, + "Outlier Threshold", + ha="right", + va="bottom", + color=cutoff_color, + fontweight="bold", + fontsize=8, + ) + + ax_scatter.set_xlabel("Time (s)") + ax_scatter.set_ylabel("Amplitude scaling factor") + title_str = f"Unit {unit_id}" if unit_id is not None else "Amplitudes over time" + ax_scatter.set_title(title_str) + ax_scatter.spines["top"].set_visible(False) + ax_scatter.spines["right"].set_visible(False) + + elif ax_scatter is not None: + ax_scatter.text( + 0.5, + 0.5, + "Spike times not provided", + ha="center", + va="center", + transform=ax_scatter.transAxes, + ) + ax_scatter.set_title("Scatter plot requires spike_times") + + # Plot 2: Histogram with Gaussian fit + if ax_hist is not None: + # Plot histogram as horizontal bars (like MATLAB) + bin_centers = (b[:-1] + b[1:]) / 2 + ax_hist.barh(bin_centers, h, height=bin_size * 0.9, color=main_color, alpha=0.7, label="Histogram") + + # Plot smoothed PDF (Gaussian fit) + ax_hist.plot(pdf, support, color=fit_color, linewidth=2, label="Smoothed PDF") + + # Mark the cutoff point G + cutoff_amplitude = support[G] + ax_hist.axhline(cutoff_amplitude, color=cutoff_color, linestyle="--", linewidth=1.5, label="Cutoff") + + # Mark the peak + peak_amplitude = support[peak_index] + ax_hist.axhline(peak_amplitude, color="green", linestyle=":", linewidth=1.5, label="Peak") + + ax_hist.set_xlabel("Density") + ax_hist.set_ylabel("Amplitude") + + # Add percent missing text + rounded_p = f"{fraction_missing * 100:.1f}%" + 
title_str = f"% missing spikes: {rounded_p}" + if unit_id is not None: + title_str = f"Unit {unit_id}\n{title_str}" + ax_hist.set_title(title_str, color=[0.7, 0.7, 0.7]) + + ax_hist.legend(loc="upper right", fontsize=8) + ax_hist.spines["top"].set_visible(False) + ax_hist.spines["right"].set_visible(False) + + if created_figure: + plt.tight_layout() + plt.show() + return fraction_missing diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py index a1af1de348..53148fac85 100644 --- a/src/spikeinterface/metrics/template/metrics.py +++ b/src/spikeinterface/metrics/template/metrics.py @@ -2,65 +2,534 @@ import numpy as np from collections import namedtuple - +from scipy.signal import find_peaks, savgol_filter from spikeinterface.core.analyzer_extension_core import BaseMetric -def get_trough_and_peak_idx(template): +def get_trough_and_peak_idx( + template, min_thresh_detect_peaks_troughs=0.4, smooth=True, smooth_window_frac=0.1, smooth_polyorder=3 +): """ - Return the indices into the input template of the detected trough - (minimum of template) and peak (maximum of template, after trough). - Assumes negative trough and positive peak. + Detect troughs and peaks in a template waveform and return detailed information + about each detected feature. Parameters ---------- - template: numpy.ndarray + template : numpy.ndarray The 1D template waveform + min_thresh_detect_peaks_troughs : float, default: 0.4 + Minimum prominence threshold as a fraction of the template's absolute max value + smooth : bool, default: True + Whether to apply smoothing before peak detection + smooth_window_frac : float, default: 0.1 + Smoothing window length as a fraction of template length (0.05-0.2 recommended) + smooth_polyorder : int, default: 3 + Polynomial order for Savitzky-Golay filter (must be < window_length) Returns ------- - trough_idx: int - The index of the trough - peak_idx: int - The index of the peak + troughs : dict + Dictionary containing: + - "indices": array of all trough indices + - "values": array of all trough values + - "prominences": array of all trough prominences + - "widths": array of all trough widths + - "main_idx": index of the main trough (most prominent) + - "main_loc": location (sample index) of the main trough in template + peaks_before : dict + Dictionary containing peaks detected before the main trough (initial peaks): + - "indices": array of all peak indices (in original template coordinates) + - "values": array of all peak values + - "prominences": array of all peak prominences + - "widths": array of all peak widths + - "main_idx": index of the main peak (most prominent) + - "main_loc": location (sample index) of the main peak in template + peaks_after : dict + Dictionary containing peaks detected after the main trough (repolarization peaks): + - "indices": array of all peak indices (in original template coordinates) + - "values": array of all peak values + - "prominences": array of all peak prominences + - "widths": array of all peak widths + - "main_idx": index of the main peak (most prominent) + - "main_loc": location (sample index) of the main peak in template """ assert template.ndim == 1 - trough_idx = np.argmin(template) - peak_idx = trough_idx + np.argmax(template[trough_idx:]) - return trough_idx, peak_idx + # Save original for plotting + template_original = template.copy() + + # Smooth template to reduce noise while preserving peaks using Savitzky-Golay filter + if smooth: + window_length = int(len(template) * 
smooth_window_frac) // 2 * 2 + 1 + window_length = max(smooth_polyorder + 2, window_length) | 1 # Must be odd (required by savgol_filter) and > polyorder + template = savgol_filter(template, window_length=window_length, polyorder=smooth_polyorder) + + # Initialize empty result dictionaries + empty_dict = { + "indices": np.array([], dtype=int), + "values": np.array([]), + "prominences": np.array([]), + "widths": np.array([]), + "main_idx": None, + "main_loc": None, + } + + # Get min prominence to detect peaks and troughs relative to template abs max value + min_prominence = min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template)) + + # --- Find troughs (by inverting waveform and using find_peaks) --- + trough_locs, trough_props = find_peaks(-template, prominence=min_prominence, width=0) + + if len(trough_locs) == 0: + # Fallback: use global minimum + trough_locs = np.array([np.nanargmin(template)]) + trough_props = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])} + + # Determine main trough (most prominent, or first if no valid prominences) + trough_prominences = trough_props.get("prominences", np.array([])) + if len(trough_prominences) > 0 and not np.all(np.isnan(trough_prominences)): + main_trough_idx = np.nanargmax(trough_prominences) + else: + main_trough_idx = 0 + + main_trough_loc = trough_locs[main_trough_idx] + + troughs = { + "indices": trough_locs, + "values": template[trough_locs], + "prominences": trough_props.get("prominences", np.full(len(trough_locs), np.nan)), + "widths": trough_props.get("widths", np.full(len(trough_locs), np.nan)), + "main_idx": main_trough_idx, + "main_loc": main_trough_loc, + } + + # --- Find peaks before the main trough --- + if main_trough_loc > 3: + template_before = template[:main_trough_loc] + + # Try with original prominence + peak_locs_before, peak_props_before = find_peaks(template_before, prominence=min_prominence, width=0) + + # If no peaks found, try with lower prominence (keep only max peak) + if len(peak_locs_before) == 0: + lower_prominence = 0.075 * min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template)) + peak_locs_before, peak_props_before = find_peaks(template_before, prominence=lower_prominence, width=0) + # Keep only the most prominent peak when using lower threshold + if len(peak_locs_before) > 1: + prominences = peak_props_before.get("prominences", np.array([])) + if len(prominences) > 0 and not np.all(np.isnan(prominences)): + max_idx = np.nanargmax(prominences) + peak_locs_before = np.array([peak_locs_before[max_idx]]) + peak_props_before = { + "prominences": np.array([prominences[max_idx]]), + "widths": np.array([peak_props_before.get("widths", np.array([np.nan]))[max_idx]]), + } + + # If still no peaks found, fall back to argmax + if len(peak_locs_before) == 0: + peak_locs_before = np.array([np.nanargmax(template_before)]) + peak_props_before = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])} + + peak_prominences_before = peak_props_before.get("prominences", np.array([])) + if len(peak_prominences_before) > 0 and not np.all(np.isnan(peak_prominences_before)): + main_peak_before_idx = np.nanargmax(peak_prominences_before) + else: + main_peak_before_idx = 0 + + peaks_before = { + "indices": peak_locs_before, + "values": template[peak_locs_before], + "prominences": peak_props_before.get("prominences", 
np.full(len(peak_locs_before), np.nan)), + "widths": peak_props_before.get("widths", np.full(len(peak_locs_before), np.nan)), + "main_idx": main_peak_before_idx, + "main_loc": peak_locs_before[main_peak_before_idx], + } + else: + peaks_before = empty_dict.copy() + + # --- Find peaks after the main trough (repolarization peaks) --- + if main_trough_loc < len(template) - 3: + template_after = template[main_trough_loc:] + + # Try with original prominence + peak_locs_after, peak_props_after = find_peaks(template_after, prominence=min_prominence, width=0) + + # If no peaks found, try with lower prominence (keep only max peak) + if len(peak_locs_after) == 0: + lower_prominence = 0.075 * min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template)) + peak_locs_after, peak_props_after = find_peaks(template_after, prominence=lower_prominence, width=0) + # Keep only the most prominent peak when using lower threshold + if len(peak_locs_after) > 1: + prominences = peak_props_after.get("prominences", np.array([])) + if len(prominences) > 0 and not np.all(np.isnan(prominences)): + max_idx = np.nanargmax(prominences) + peak_locs_after = np.array([peak_locs_after[max_idx]]) + peak_props_after = { + "prominences": np.array([prominences[max_idx]]), + "widths": np.array([peak_props_after.get("widths", np.array([np.nan]))[max_idx]]), + } + + # If still no peaks found, fall back to argmax + if len(peak_locs_after) == 0: + peak_locs_after = np.array([np.nanargmax(template_after)]) + peak_props_after = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])} + + # Convert to original template coordinates + peak_locs_after_abs = peak_locs_after + main_trough_loc + + peak_prominences_after = peak_props_after.get("prominences", np.array([])) + if len(peak_prominences_after) > 0 and not np.all(np.isnan(peak_prominences_after)): + main_peak_after_idx = np.nanargmax(peak_prominences_after) + else: + main_peak_after_idx = 0 + + peaks_after = { + "indices": peak_locs_after_abs, + "values": template[peak_locs_after_abs], + "prominences": peak_props_after.get("prominences", np.full(len(peak_locs_after), np.nan)), + "widths": peak_props_after.get("widths", np.full(len(peak_locs_after), np.nan)), + "main_idx": main_peak_after_idx, + "main_loc": peak_locs_after_abs[main_peak_after_idx], + } + else: + peaks_after = empty_dict.copy() + + # Quick visualization (set to True for debugging) + _plot = False + if _plot: + import matplotlib.pyplot as plt + + # Old simple method for comparison (argmin/argmax) + old_trough_idx = np.nanargmin(template) + old_peak_idx = np.nanargmax(template[old_trough_idx:]) + old_trough_idx + + fig, ax = plt.subplots(figsize=(10, 5)) + ax.plot(template_original, color="lightgray", lw=1, label="original (noisy)") + ax.plot(template, "k-", lw=1.5, label="smoothed") + + # Plot old method (simple argmin/argmax) + ax.axvline(old_trough_idx, color="gray", ls="--", alpha=0.5, label="old trough (argmin)") + ax.axvline(old_peak_idx, color="gray", ls=":", alpha=0.5, label="old peak (argmax after trough)") + + # Plot all detected troughs + ax.scatter(troughs["indices"], troughs["values"], c="blue", s=50, marker="v", zorder=5, label="troughs") + if troughs["main_loc"] is not None: + ax.scatter( + troughs["main_loc"], + template[troughs["main_loc"]], + c="blue", + s=150, + marker="v", + edgecolors="red", + linewidths=2, + zorder=6, + label="main trough", + ) + + # Plot all peaks before + if len(peaks_before["indices"]) > 0: + ax.scatter( + peaks_before["indices"], + 
peaks_before["values"], + c="green", + s=50, + marker="^", + zorder=5, + label="peaks before", + ) + if peaks_before["main_loc"] is not None: + ax.scatter( + peaks_before["main_loc"], + template[peaks_before["main_loc"]], + c="green", + s=150, + marker="^", + edgecolors="red", + linewidths=2, + zorder=6, + label="main peak before", + ) + + # Plot all peaks after + if len(peaks_after["indices"]) > 0: + ax.scatter( + peaks_after["indices"], + peaks_after["values"], + c="orange", + s=50, + marker="^", + zorder=5, + label="peaks after", + ) + if peaks_after["main_loc"] is not None: + ax.scatter( + peaks_after["main_loc"], + template[peaks_after["main_loc"]], + c="orange", + s=150, + marker="^", + edgecolors="red", + linewidths=2, + zorder=6, + label="main peak after", + ) + + ax.axhline(0, color="gray", ls="-", alpha=0.3) + ax.set_xlabel("Sample") + ax.set_ylabel("Amplitude") + ax.legend(loc="best", fontsize=8) + ax.set_title(f"Trough/Peak Detection (prominence threshold: {min_thresh_detect_peaks_troughs})") + plt.tight_layout() + plt.show() + + return troughs, peaks_before, peaks_after + + +def get_waveform_duration(template, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs): """ - Return the peak to valley duration in seconds of input waveforms. + Calculate waveform duration from the main extremum to the next extremum. + + The duration is measured from the largest absolute feature (main trough or main peak) + to the next extremum. For typical negative-first waveforms, this is trough-to-peak. + For positive-first waveforms, this is peak-to-trough. Parameters ---------- - template_single: numpy.ndarray + template : numpy.ndarray The 1D template waveform sampling_frequency : float - The sampling frequency of the template - trough_idx: int, default: None - The index of the trough - peak_idx: int, default: None - The index of the peak + The sampling frequency in Hz + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx Returns ------- - ptv: float - The peak to valley duration in seconds + waveform_duration_us : float + Waveform duration in microseconds """ - if trough_idx is None or peak_idx is None: - trough_idx, peak_idx = get_trough_and_peak_idx(template_single) - ptv = (peak_idx - trough_idx) / sampling_frequency - return ptv + # Get main locations and values + trough_loc = troughs["main_loc"] + trough_val = template[trough_loc] if trough_loc is not None else None + + peak_before_loc = peaks_before["main_loc"] + peak_before_val = template[peak_before_loc] if peak_before_loc is not None else None + + peak_after_loc = peaks_after["main_loc"] + peak_after_val = template[peak_after_loc] if peak_after_loc is not None else None + + # Find the main extremum (largest absolute value) + candidates = [] + if trough_loc is not None and trough_val is not None: + candidates.append(("trough", trough_loc, abs(trough_val))) + if peak_before_loc is not None and peak_before_val is not None: + candidates.append(("peak_before", peak_before_loc, abs(peak_before_val))) + if peak_after_loc is not None and peak_after_val is not None: + candidates.append(("peak_after", peak_after_loc, abs(peak_after_val))) + + if len(candidates) == 0: + return np.nan -def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs) -> float: + # Sort by absolute value to find main extremum + 
candidates.sort(key=lambda x: x[2], reverse=True) + main_type, main_loc, _ = candidates[0] + + # Find the next extremum after the main one + if main_type == "trough": + # Main is trough, next is peak_after + if peak_after_loc is not None: + duration_samples = abs(peak_after_loc - main_loc) + elif peak_before_loc is not None: + duration_samples = abs(main_loc - peak_before_loc) + else: + return np.nan + elif main_type == "peak_before": + # Main is peak before, next is trough + if trough_loc is not None: + duration_samples = abs(trough_loc - main_loc) + else: + return np.nan + else: # peak_after + # Main is peak after, previous is trough + if trough_loc is not None: + duration_samples = abs(main_loc - trough_loc) + else: + return np.nan + + # Convert to microseconds + waveform_duration_us = (duration_samples / sampling_frequency) * 1e6 + + return waveform_duration_us + + + def get_waveform_ratios(template, troughs, peaks_before, peaks_after, **kwargs): + """ + Calculate various waveform amplitude ratios. + + Parameters + ---------- + template : numpy.ndarray + The 1D template waveform + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx + + Returns + ------- + ratios : dict + Dictionary containing: + - "peak_before_to_trough_ratio": ratio of peak before to trough amplitude + - "peak_after_to_trough_ratio": ratio of peak after to trough amplitude + - "peak_before_to_peak_after_ratio": ratio of peak before to peak after amplitude + - "main_peak_to_trough_ratio": ratio of larger peak to trough amplitude + """ + # Get absolute amplitudes + trough_amp = abs(template[troughs["main_loc"]]) if troughs["main_loc"] is not None else np.nan + peak_before_amp = abs(template[peaks_before["main_loc"]]) if peaks_before["main_loc"] is not None else np.nan + peak_after_amp = abs(template[peaks_after["main_loc"]]) if peaks_after["main_loc"] is not None else np.nan + + def safe_ratio(a, b): + if np.isnan(a) or np.isnan(b) or b == 0: + return np.nan + return a / b + + ratios = { + "peak_before_to_trough_ratio": safe_ratio(peak_before_amp, trough_amp), + "peak_after_to_trough_ratio": safe_ratio(peak_after_amp, trough_amp), + "peak_before_to_peak_after_ratio": safe_ratio(peak_before_amp, peak_after_amp), + "main_peak_to_trough_ratio": safe_ratio( + ( + max(peak_before_amp, peak_after_amp) + if not (np.isnan(peak_before_amp) and np.isnan(peak_after_amp)) + else np.nan + ), + trough_amp, + ), + } + + return ratios + + + def get_waveform_baseline_flatness(template, sampling_frequency, **kwargs): + """ + Compute the baseline flatness of the waveform. + + This metric measures the ratio of the max absolute amplitude in the baseline + window to the max absolute amplitude of the whole waveform. A lower value + indicates a flat baseline (expected for good units). + + Parameters + ---------- + template : numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency in Hz + **kwargs : Optional kwargs: + - baseline_window_ms : tuple of (start_ms, end_ms) defining the baseline window + relative to waveform start. Default is (0, 0.5), i.e. the first 0.5 ms. + + Returns + ------- + baseline_flatness : float + Ratio of max(abs(baseline)) / max(abs(waveform)). Lower = flatter baseline. 
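+ + Examples + -------- + A hypothetical template that is silent over the first 0.5 ms (15 samples at + 30 kHz) has a perfectly flat baseline: + + >>> template = np.concatenate([np.zeros(15), -np.hanning(30)]) + >>> get_waveform_baseline_flatness(template, sampling_frequency=30000.0) + 0.0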
+ """ + baseline_window_ms = kwargs.get("baseline_window_ms", (0.0, 0.5)) + + if baseline_window_ms is None: + return np.nan + + start_ms, end_ms = baseline_window_ms + start_idx = int(start_ms / 1000 * sampling_frequency) + end_idx = int(end_ms / 1000 * sampling_frequency) + + # Clamp to valid range + start_idx = max(0, start_idx) + end_idx = min(len(template), end_idx) + + if end_idx <= start_idx: + return np.nan + + baseline_segment = template[start_idx:end_idx] + + if len(baseline_segment) == 0: + return np.nan + + max_baseline = np.nanmax(np.abs(baseline_segment)) + max_waveform = np.nanmax(np.abs(template)) + + if max_waveform == 0 or np.isnan(max_waveform): + return np.nan + + baseline_flatness = max_baseline / max_waveform + + return baseline_flatness + + +def get_waveform_widths(template, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs): + """ + Get the widths of the main trough and peaks in microseconds. + + Parameters + ---------- + template : numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency in Hz + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx + + Returns + ------- + widths : dict + Dictionary containing: + - "trough_width_us": width of main trough in microseconds + - "peak_before_width_us": width of main peak before trough in microseconds + - "peak_after_width_us": width of main peak after trough in microseconds + """ + + def get_main_width(feature_dict): + if feature_dict["main_idx"] is None: + return np.nan + widths = feature_dict.get("widths", np.array([])) + if len(widths) == 0: + return np.nan + main_idx = feature_dict["main_idx"] + if main_idx < len(widths): + return widths[main_idx] + return np.nan + + # Convert from samples to microseconds + samples_to_us = 1e6 / sampling_frequency + + trough_width = get_main_width(troughs) + peak_before_width = get_main_width(peaks_before) + peak_after_width = get_main_width(peaks_after) + + widths = { + "trough_width_us": trough_width * samples_to_us if not np.isnan(trough_width) else np.nan, + "peak_before_width_us": peak_before_width * samples_to_us if not np.isnan(peak_before_width) else np.nan, + "peak_after_width_us": peak_after_width * samples_to_us if not np.isnan(peak_after_width) else np.nan, + } + + return widths + + +######################################################################################### +# Single-channel metrics +def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: + """ + Return the peak to valley duration in seconds of input waveforms. 
Parameters ---------- @@ -75,13 +544,17 @@ def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=N Returns ------- - ptratio: float - The peak to trough ratio + ptv: float + The peak to valley duration in seconds """ if trough_idx is None or peak_idx is None: - trough_idx, peak_idx = get_trough_and_peak_idx(template_single) - ptratio = template_single[peak_idx] / template_single[trough_idx] - return ptratio + troughs, _, peaks_after = get_trough_and_peak_idx(template_single) + trough_idx = troughs["main_loc"] + peak_idx = peaks_after["main_loc"] + if trough_idx is None or peak_idx is None: + return np.nan + ptv = (peak_idx - trough_idx) / sampling_frequency + return ptv def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: @@ -105,9 +578,11 @@ def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_id The half width in seconds """ if trough_idx is None or peak_idx is None: - trough_idx, peak_idx = get_trough_and_peak_idx(template_single) + troughs, _, peaks_after = get_trough_and_peak_idx(template_single) + trough_idx = troughs["main_loc"] + peak_idx = peaks_after["main_loc"] - if peak_idx == 0: + if trough_idx is None or peak_idx is None or peak_idx == 0: return np.nan trough_val = template_single[trough_idx] @@ -156,11 +631,12 @@ def get_repolarization_slope(template_single, sampling_frequency, trough_idx=Non The repolarization slope """ if trough_idx is None: - trough_idx, _ = get_trough_and_peak_idx(template_single) + troughs, _, _ = get_trough_and_peak_idx(template_single) + trough_idx = troughs["main_loc"] times = np.arange(template_single.shape[0]) / sampling_frequency - if trough_idx == 0: + if trough_idx is None or trough_idx == 0: return np.nan (rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0) @@ -209,11 +685,12 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa assert "recovery_window_ms" in kwargs, "recovery_window_ms must be given as kwarg" recovery_window_ms = kwargs["recovery_window_ms"] if peak_idx is None: - _, peak_idx = get_trough_and_peak_idx(template_single) + _, _, peaks_after = get_trough_and_peak_idx(template_single) + peak_idx = peaks_after["main_loc"] times = np.arange(template_single.shape[0]) / sampling_frequency - if peak_idx == 0: + if peak_idx is None or peak_idx == 0: return np.nan max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency)) max_idx = np.min([max_idx, template_single.shape[0]]) @@ -222,9 +699,12 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa return res.slope -def get_number_of_peaks(template_single, sampling_frequency, **kwargs): +def get_number_of_peaks(template_single, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs): """ - Count the total number of peaks (positive + negative) in the template. + Count the total number of peaks (positive) and troughs (negative) in the template. + + Uses the pre-computed peak/trough detection from get_trough_and_peak_idx, which + applies smoothing for more robust detection. 
Parameters ---------- @@ -232,28 +712,28 @@ def get_number_of_peaks(template_single, sampling_frequency, **kwargs): The 1D template waveform sampling_frequency : float The sampling frequency of the template - **kwargs: Required kwargs: - - peak_relative_threshold: the relative threshold to detect positive and negative peaks - - peak_width_ms: the width in samples to detect peaks + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx Returns ------- - number_of_peaks: int - the total number of peaks (positive + negative) - """ - from scipy.signal import find_peaks - - assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" - assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" - peak_relative_threshold = kwargs["peak_relative_threshold"] - peak_width_ms = kwargs["peak_width_ms"] - max_value = np.max(np.abs(template_single)) - peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) - - pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) - neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) - num_positive = len(pos_peaks[0]) - num_negative = len(neg_peaks[0]) + num_positive_peaks : int + The number of positive peaks (peaks_before + peaks_after) + num_negative_peaks : int + The number of negative peaks (troughs) + """ + # Count peaks (positive) from peaks_before and peaks_after + num_peaks_before = len(peaks_before["indices"]) + num_peaks_after = len(peaks_after["indices"]) + num_positive = num_peaks_before + num_peaks_after + + # Count troughs (negative) + num_negative = len(troughs["indices"]) + return num_positive, num_negative @@ -293,7 +773,7 @@ def fit_velocity(peak_times, channel_dist): from sklearn.linear_model import TheilSenRegressor - theil = TheilSenRegressor() + theil = TheilSenRegressor(max_iter=1000) theil.fit(peak_times.reshape(-1, 1), channel_dist) slope = theil.coef_[0] intercept = theil.intercept_ @@ -376,7 +856,11 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): """ - Compute the exponential decay of the template amplitude over distance in units um/s. + Compute the spatial decay of the template amplitude over distance. + + Can fit either an exponential decay (with offset) or a linear decay model. Channels are first + filtered by x-distance tolerance from the max channel, then the closest channels + in y-distance are used for fitting. 
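+ + The exponential model is ``amp(x) = amp0 * exp(-decay * x) + offset`` and the linear model is ``amp(x) = a * x + b``; the reported value is the fitted decay constant (exponential) or the negated slope (linear). + 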
Parameters ---------- @@ -387,13 +871,18 @@ def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs sampling_frequency : float The sampling frequency of the template **kwargs: Required kwargs: - - peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") - - min_r2: the minimum r2 to accept the exp decay fit + - peak_function: the function to use to compute the peak amplitude ("ptp" or "min") + - min_r2: the minimum r2 to accept the fit + - linear_fit: bool, if True use linear fit, otherwise exponential fit + - channel_tolerance: max x-distance (um) from max channel to include channels + - min_channels_for_fit: minimum number of valid channels required for fitting + - num_channels_for_fit: number of closest channels to use for fitting + - normalize_decay: bool, if True normalize amplitudes to max before fitting Returns ------- exp_decay_value : float - The exponential decay of the template amplitude + The spatial decay slope (decay constant for exp fit, negative slope for linear fit) """ from scipy.optimize import curve_fit from sklearn.metrics import r2_score @@ -401,41 +890,117 @@ def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs def exp_decay(x, decay, amp0, offset): return amp0 * np.exp(-decay * x) + offset + def linear_fit_func(x, a, b): + return a * x + b + + # Extract parameters assert "peak_function" in kwargs, "peak_function must be given as kwarg" peak_function = kwargs["peak_function"] assert "min_r2" in kwargs, "min_r2 must be given as kwarg" min_r2 = kwargs["min_r2"] - # exp decay fit + + use_linear_fit = kwargs.get("linear_fit", False) + channel_tolerance = kwargs.get("channel_tolerance", None) + normalize_decay = kwargs.get("normalize_decay", False) + + # Set defaults based on fit type if not specified + min_channels_for_fit = kwargs.get("min_channels_for_fit") + if min_channels_for_fit is None: + min_channels_for_fit = 5 if use_linear_fit else 8 + + num_channels_for_fit = kwargs.get("num_channels_for_fit") + if num_channels_for_fit is None: + num_channels_for_fit = 6 if use_linear_fit else 10 + + # Compute peak amplitudes per channel if peak_function == "ptp": fun = np.ptp elif peak_function == "min": fun = np.min + else: + fun = np.ptp + peak_amplitudes = np.abs(fun(template, axis=0)) - max_channel_location = channel_locations[np.argmax(peak_amplitudes)] - channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) - distances_sort_indices = np.argsort(channel_distances) + max_channel_idx = np.argmax(peak_amplitudes) + max_channel_location = channel_locations[max_channel_idx] - # longdouble is float128 when the platform supports it, otherwise it is float64 - channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.longdouble) - peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.longdouble) + # Channel selection based on tolerance (new bombcell-style) or use all channels (old style) + if channel_tolerance is not None: + # Calculate x-distances from max channel + x_dist = np.abs(channel_locations[:, 0] - max_channel_location[0]) - try: - amp0 = peak_amplitudes_sorted[0] - offset0 = np.min(peak_amplitudes_sorted) - - popt, _ = curve_fit( - exp_decay, - channel_distances_sorted, - peak_amplitudes_sorted, - bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), - p0=[1e-3, peak_amplitudes_sorted[0], offset0], + # Find channels within x-distance tolerance + valid_x_channels = 
np.argwhere(x_dist <= channel_tolerance).flatten() + + if len(valid_x_channels) < min_channels_for_fit: + return np.nan + + # Calculate y-distances for channel selection + y_dist = np.abs(channel_locations[:, 1] - max_channel_location[1]) + + # Set y distances to max for channels outside x tolerance (so they won't be selected) + y_dist_masked = y_dist.copy() + y_dist_masked[~np.isin(np.arange(len(y_dist)), valid_x_channels)] = y_dist.max() + 1 + + # Select the closest channels in y-distance + use_these_channels = np.argsort(y_dist_masked)[:num_channels_for_fit] + + # Calculate distances from max channel for selected channels + channel_distances = np.sqrt( + np.sum(np.square(channel_locations[use_these_channels] - max_channel_location), axis=1) ) - r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) - exp_decay_value = popt[0] + + # Get amplitudes for selected channels + spatial_decay_points = np.max(np.abs(template[:, use_these_channels]), axis=0) + + # Sort by distance + sort_idx = np.argsort(channel_distances) + channel_distances_sorted = channel_distances[sort_idx] + peak_amplitudes_sorted = spatial_decay_points[sort_idx] + + # Normalize if requested + if normalize_decay: + peak_amplitudes_sorted = peak_amplitudes_sorted / np.max(peak_amplitudes_sorted) + + # Ensure float64 for numerical stability + channel_distances_sorted = np.float64(channel_distances_sorted) + peak_amplitudes_sorted = np.float64(peak_amplitudes_sorted) + + else: + # Old style: use all channels sorted by distance + channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) + distances_sort_indices = np.argsort(channel_distances) + + # longdouble is float128 when the platform supports it, otherwise it is float64 + channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.longdouble) + peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.longdouble) + + try: + if use_linear_fit: + # Linear fit: y = a*x + b + popt, _ = curve_fit(linear_fit_func, channel_distances_sorted, peak_amplitudes_sorted) + predicted = linear_fit_func(channel_distances_sorted, *popt) + r2 = r2_score(peak_amplitudes_sorted, predicted) + exp_decay_value = -popt[0] # Negative of slope + else: + # Exponential fit with offset: y = amp0 * exp(-decay * x) + offset + amp0 = peak_amplitudes_sorted[0] + offset0 = np.min(peak_amplitudes_sorted) + + popt, _ = curve_fit( + exp_decay, + channel_distances_sorted, + peak_amplitudes_sorted, + bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), + p0=[1e-3, peak_amplitudes_sorted[0], offset0], + ) + r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) + exp_decay_value = popt[0] if r2 < min_r2: exp_decay_value = np.nan - except: + + except Exception: exp_decay_value = np.nan return exp_decay_value @@ -512,17 +1077,17 @@ def single_channel_metric(unit_function, sorting_analyzer, unit_ids, tmp_data, * return result -class PeakToValley(BaseMetric): - metric_name = "peak_to_valley" +class PeakToTroughDuration(BaseMetric): + metric_name = "peak_to_trough_duration" metric_params = {} - metric_columns = {"peak_to_valley": float} + metric_columns = {"peak_to_trough_duration": float} metric_descriptions = { - "peak_to_valley": "Duration in s between the trough (minimum) and the peak (maximum) of the spike waveform." + "peak_to_trough_duration": "Duration in seconds between the trough (minimum) and the peak (maximum) of the spike waveform." 
} needs_tmp_data = True @staticmethod - def _peak_to_valley_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + def _peak_to_trough_duration_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): return single_channel_metric( unit_function=get_peak_to_valley, sorting_analyzer=sorting_analyzer, @@ -531,29 +1096,7 @@ def _peak_to_valley_metric_function(sorting_analyzer, unit_ids, tmp_data, **metr **metric_params, ) - metric_function = _peak_to_valley_metric_function - - -class PeakToTroughRatio(BaseMetric): - metric_name = "peak_trough_ratio" - metric_params = {} - metric_columns = {"peak_trough_ratio": float} - metric_descriptions = { - "peak_trough_ratio": "Ratio of the amplitude of the peak (maximum) to the trough (minimum) of the spike waveform." - } - needs_tmp_data = True - - @staticmethod - def _peak_to_trough_ratio_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - return single_channel_metric( - unit_function=get_peak_trough_ratio, - sorting_analyzer=sorting_analyzer, - unit_ids=unit_ids, - tmp_data=tmp_data, - **metric_params, - ) - - metric_function = _peak_to_trough_ratio_metric_function + metric_function = _peak_to_trough_duration_metric_function class HalfWidth(BaseMetric): @@ -626,11 +1169,21 @@ def _number_of_peaks_metric_function(sorting_analyzer, unit_ids, tmp_data, **met num_peaks_result = namedtuple("NumberOfPeaksResult", ["num_positive_peaks", "num_negative_peaks"]) num_positive_peaks_dict = {} num_negative_peaks_dict = {} - sampling_frequency = sorting_analyzer.sampling_frequency + sampling_frequency = tmp_data["sampling_frequency"] templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] for unit_index, unit_id in enumerate(unit_ids): template_single = templates_single[unit_index] - num_positive, num_negative = get_number_of_peaks(template_single, sampling_frequency, **metric_params) + num_positive, num_negative = get_number_of_peaks( + template_single, + sampling_frequency, + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + **metric_params, + ) num_positive_peaks_dict[unit_id] = num_positive num_negative_peaks_dict[unit_id] = num_negative return num_peaks_result(num_positive_peaks=num_positive_peaks_dict, num_negative_peaks=num_negative_peaks_dict) @@ -639,22 +1192,184 @@ def _number_of_peaks_metric_function(sorting_analyzer, unit_ids, tmp_data, **met class NumberOfPeaks(BaseMetric): metric_name = "number_of_peaks" metric_function = _number_of_peaks_metric_function - metric_params = {"peak_relative_threshold": 0.2, "peak_width_ms": 0.1} + metric_params = {} metric_columns = {"num_positive_peaks": int, "num_negative_peaks": int} metric_descriptions = { "num_positive_peaks": "Number of positive peaks in the template", - "num_negative_peaks": "Number of negative peaks in the template", + "num_negative_peaks": "Number of negative peaks (troughs) in the template", + } + needs_tmp_data = True + + +def _waveform_duration_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + result = {} + templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] + sampling_frequency = tmp_data["sampling_frequency"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + 
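+        # The trough/peak dicts were computed once in _prepare_data, so this
+        # metric reuses the same smoothed detection rather than re-detecting here.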
value = get_waveform_duration( + template_single, + sampling_frequency, + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + **metric_params, + ) + result[unit_id] = value + return result + + +class WaveformDuration(BaseMetric): + metric_name = "waveform_duration" + metric_function = _waveform_duration_metric_function + metric_params = {} + metric_columns = {"waveform_duration": float} + metric_descriptions = { + "waveform_duration": "Waveform duration in microseconds from main extremum to next extremum." + } + needs_tmp_data = True + + +def _waveform_ratios_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + waveform_ratios_result = namedtuple( + "WaveformRatiosResult", + [ + "peak_before_to_trough_ratio", + "peak_after_to_trough_ratio", + "peak_before_to_peak_after_ratio", + "main_peak_to_trough_ratio", + ], + ) + peak_before_to_trough = {} + peak_after_to_trough = {} + peak_before_to_peak_after = {} + main_peak_to_trough = {} + templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + ratios = get_waveform_ratios( + template_single, + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + **metric_params, + ) + peak_before_to_trough[unit_id] = ratios["peak_before_to_trough_ratio"] + peak_after_to_trough[unit_id] = ratios["peak_after_to_trough_ratio"] + peak_before_to_peak_after[unit_id] = ratios["peak_before_to_peak_after_ratio"] + main_peak_to_trough[unit_id] = ratios["main_peak_to_trough_ratio"] + return waveform_ratios_result( + peak_before_to_trough_ratio=peak_before_to_trough, + peak_after_to_trough_ratio=peak_after_to_trough, + peak_before_to_peak_after_ratio=peak_before_to_peak_after, + main_peak_to_trough_ratio=main_peak_to_trough, + ) + + +class WaveformRatios(BaseMetric): + metric_name = "waveform_ratios" + metric_function = _waveform_ratios_metric_function + metric_params = {} + metric_columns = { + "peak_before_to_trough_ratio": float, + "peak_after_to_trough_ratio": float, + "peak_before_to_peak_after_ratio": float, + "main_peak_to_trough_ratio": float, + } + metric_descriptions = { + "peak_before_to_trough_ratio": "Ratio of peak before amplitude to trough amplitude", + "peak_after_to_trough_ratio": "Ratio of peak after amplitude to trough amplitude", + "peak_before_to_peak_after_ratio": "Ratio of peak before amplitude to peak after amplitude", + "main_peak_to_trough_ratio": "Ratio of main peak amplitude to trough amplitude", + } + needs_tmp_data = True + + +def _waveform_widths_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + waveform_widths_result = namedtuple( + "WaveformWidthsResult", ["trough_width", "peak_before_width", "peak_after_width"] + ) + trough_width_dict = {} + peak_before_width_dict = {} + peak_after_width_dict = {} + templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] + sampling_frequency = tmp_data["sampling_frequency"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + widths = get_waveform_widths( + template_single, + sampling_frequency, + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + 
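+            # Width values are returned in microseconds (keys trough_width_us, etc.)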
**metric_params, + ) + trough_width_dict[unit_id] = widths["trough_width_us"] + peak_before_width_dict[unit_id] = widths["peak_before_width_us"] + peak_after_width_dict[unit_id] = widths["peak_after_width_us"] + return waveform_widths_result( + trough_width=trough_width_dict, peak_before_width=peak_before_width_dict, peak_after_width=peak_after_width_dict + ) + + +class WaveformWidths(BaseMetric): + metric_name = "waveform_widths" + metric_function = _waveform_widths_metric_function + metric_params = {} + metric_columns = { + "trough_width": float, + "peak_before_width": float, + "peak_after_width": float, + } + metric_descriptions = { + "trough_width": "Width of the main trough in microseconds", + "peak_before_width": "Width of the main peak before trough in microseconds", + "peak_after_width": "Width of the main peak after trough in microseconds", + } + needs_tmp_data = True + + +def _waveform_baseline_flatness_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + result = {} + templates_single = tmp_data["templates_single"] + sampling_frequency = tmp_data["sampling_frequency"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + value = get_waveform_baseline_flatness(template_single, sampling_frequency, **metric_params) + result[unit_id] = value + return result + + +class WaveformBaselineFlatness(BaseMetric): + metric_name = "waveform_baseline_flatness" + metric_function = _waveform_baseline_flatness_metric_function + metric_params = {"baseline_window_ms": (0.0, 0.5)} + metric_columns = {"waveform_baseline_flatness": float} + metric_descriptions = { + "waveform_baseline_flatness": "Ratio of max baseline amplitude to max waveform amplitude. Lower = flatter baseline." } needs_tmp_data = True single_channel_metrics = [ - PeakToValley, - PeakToTroughRatio, + PeakToTroughDuration, HalfWidth, RepolarizationSlope, RecoverySlope, NumberOfPeaks, + WaveformDuration, + WaveformRatios, + WaveformWidths, + WaveformBaselineFlatness, ] @@ -707,10 +1422,21 @@ def multi_channel_metric(unit_function, sorting_analyzer, unit_ids, tmp_data, ** class ExpDecay(BaseMetric): metric_name = "exp_decay" - metric_params = {"peak_function": "ptp", "min_r2": 0.2} + metric_params = { + "peak_function": "ptp", + "min_r2": 0.2, + "linear_fit": False, + "channel_tolerance": None, # None uses old style (all channels), set to e.g. 33 for bombcell-style + "min_channels_for_fit": None, # None means use default based on linear_fit (5 for linear, 8 for exp) + "num_channels_for_fit": None, # None means use default based on linear_fit (6 for linear, 10 for exp) + "normalize_decay": False, + } metric_columns = {"exp_decay": float} metric_descriptions = { - "exp_decay": ("Exponential decay of the template amplitude over distance from the extremum channel (1/um).") + "exp_decay": ( + "Spatial decay of the template amplitude over distance from the extremum channel (1/um). " + "Uses exponential or linear fit based on linear_fit parameter." 
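+            " Set channel_tolerance (in um) to enable bombcell-style channel selection."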
+            )
     }
     needs_tmp_data = True
diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py
index 83a9048a64..f00e870c30 100644
--- a/src/spikeinterface/metrics/template/template_metrics.py
+++ b/src/spikeinterface/metrics/template/template_metrics.py
@@ -9,6 +9,7 @@
 import numpy as np
 import warnings
 from copy import deepcopy
+from scipy.signal import find_peaks
 
 from spikeinterface.core.sortinganalyzer import register_result_extension
 from spikeinterface.core.analyzer_extension_core import BaseMetricExtension
@@ -46,8 +47,8 @@ def get_template_metric_names():
 class ComputeTemplateMetrics(BaseMetricExtension):
     """
     Compute template metrics including:
-        * peak_to_valley
-        * peak_trough_ratio
+        * peak_to_trough_duration
+        * main_peak_to_trough_ratio
         * halfwidth
         * repolarization_slope
         * recovery_slope
@@ -125,6 +126,16 @@ def _handle_backward_compatibility_on_load(self):
             self.params["metric_names"].remove("velocity_below")
             if "velocity_fits" not in self.params["metric_names"]:
                 self.params["metric_names"].append("velocity_fits")
+        # peak_to_valley -> peak_to_trough_duration
+        if "peak_to_valley" in self.params["metric_names"]:
+            self.params["metric_names"].remove("peak_to_valley")
+            if "peak_to_trough_duration" not in self.params["metric_names"]:
+                self.params["metric_names"].append("peak_to_trough_duration")
+        # peak_to_trough_ratio -> main_peak_to_trough_ratio
+        if "peak_to_trough_ratio" in self.params["metric_names"]:
+            self.params["metric_names"].remove("peak_to_trough_ratio")
+            if "main_peak_to_trough_ratio" not in self.params["metric_names"]:
+                self.params["metric_names"].append("main_peak_to_trough_ratio")
 
     def _set_params(
         self,
@@ -137,6 +148,12 @@ def _set_params(
         upsampling_factor=10,
         include_multi_channel_metrics=False,
         depth_direction="y",
+        min_thresh_detect_peaks_troughs=0.4,
+        smooth=True,
+        smooth_method="savgol",
+        smooth_window_frac=0.1,
+        smooth_polyorder=3,
+        svd_n_components=3,
     ):
         # Auto-detect if multi-channel metrics should be included based on number of channels
         num_channels = self.sorting_analyzer.get_num_channels()
@@ -165,6 +182,12 @@ def _set_params(
             upsampling_factor=upsampling_factor,
             include_multi_channel_metrics=include_multi_channel_metrics,
             depth_direction=depth_direction,
+            min_thresh_detect_peaks_troughs=min_thresh_detect_peaks_troughs,
+            smooth=smooth,
+            smooth_method=smooth_method,
+            smooth_window_frac=smooth_window_frac,
+            smooth_polyorder=smooth_polyorder,
+            svd_n_components=svd_n_components,
         )
 
     def _prepare_data(self, sorting_analyzer, unit_ids):
@@ -196,6 +217,9 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
         templates_single = []
         troughs = {}
         peaks = {}
+        troughs_info = {}
+        peaks_before_info = {}
+        peaks_after_info = {}
         templates_multi = []
         channel_locations_multi = []
         for unit_id in unit_ids:
@@ -209,11 +233,22 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
             else:
                 template_upsampled = template_single
                 sampling_frequency_up = sampling_frequency
-            trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled)
+            troughs_dict, peaks_before_dict, peaks_after_dict = get_trough_and_peak_idx(
+                template_upsampled,
+                min_thresh_detect_peaks_troughs=self.params["min_thresh_detect_peaks_troughs"],
+                smooth=self.params["smooth"],
+                smooth_window_frac=self.params["smooth_window_frac"],
+                smooth_polyorder=self.params["smooth_polyorder"],
+            )
             templates_single.append(template_upsampled)
-            troughs[unit_id] = trough_idx
-            peaks[unit_id] = peak_idx
+            # Store main locations for backward compatibility
+            troughs[unit_id] = troughs_dict["main_loc"]
+            peaks[unit_id]
= peaks_after_dict["main_loc"] + # Store full dicts for new metrics + troughs_info[unit_id] = troughs_dict + peaks_before_info[unit_id] = peaks_before_dict + peaks_after_info[unit_id] = peaks_after_dict if include_multi_channel_metrics: if sorting_analyzer.is_sparse(): @@ -238,6 +273,9 @@ def _prepare_data(self, sorting_analyzer, unit_ids): tmp_data["troughs"] = troughs tmp_data["peaks"] = peaks + tmp_data["troughs_info"] = troughs_info + tmp_data["peaks_before_info"] = peaks_before_info + tmp_data["peaks_after_info"] = peaks_after_info tmp_data["templates_single"] = np.array(templates_single) if include_multi_channel_metrics: diff --git a/src/spikeinterface/widgets/unit_labelling.py b/src/spikeinterface/widgets/unit_labelling.py new file mode 100644 index 0000000000..0c01b7f528 --- /dev/null +++ b/src/spikeinterface/widgets/unit_labelling.py @@ -0,0 +1,560 @@ +"""Widgets for visualizing unit labelling results.""" + +from __future__ import annotations + +import numpy as np +from typing import Optional + +from .base import BaseWidget, to_attr + + +def _combine_metrics(quality_metrics, template_metrics): + """Combine quality_metrics and template_metrics into a single DataFrame.""" + if quality_metrics is None and template_metrics is None: + return None + if quality_metrics is None: + return template_metrics + if template_metrics is None: + return quality_metrics + return quality_metrics.join(template_metrics, how="outer") + + +class LabellingHistogramsWidget(BaseWidget): + """Plot histograms of quality metrics with threshold lines.""" + + def __init__( + self, + quality_metrics=None, + template_metrics=None, + thresholds: Optional[dict] = None, + metrics_to_plot: Optional[list] = None, + backend=None, + **backend_kwargs, + ): + from spikeinterface.curation import bombcell_get_default_thresholds + + combined_metrics = _combine_metrics(quality_metrics, template_metrics) + if combined_metrics is None: + raise ValueError("At least one of quality_metrics or template_metrics must be provided") + + if thresholds is None: + thresholds = bombcell_get_default_thresholds() + if metrics_to_plot is None: + metrics_to_plot = [m for m in thresholds.keys() if m in combined_metrics.columns] + + plot_data = dict( + quality_metrics=combined_metrics, + thresholds=thresholds, + metrics_to_plot=metrics_to_plot, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + + dp = to_attr(data_plot) + quality_metrics = dp.quality_metrics + thresholds = dp.thresholds + metrics_to_plot = dp.metrics_to_plot + + n_metrics = len(metrics_to_plot) + if n_metrics == 0: + print("No metrics to plot") + return + + n_cols = min(4, n_metrics) + n_rows = int(np.ceil(n_metrics / n_cols)) + fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows)) + if n_metrics == 1: + axes = np.array([[axes]]) + elif n_rows == 1: + axes = axes.reshape(1, -1) + elif n_cols == 1: + axes = axes.reshape(-1, 1) + + colors = plt.cm.tab10(np.linspace(0, 1, 10)) + absolute_value_metrics = ["amplitude_median"] + + for idx, metric_name in enumerate(metrics_to_plot): + row, col = idx // n_cols, idx % n_cols + ax = axes[row, col] + + values = quality_metrics[metric_name].values + if metric_name in absolute_value_metrics: + values = np.abs(values) + values = values[~np.isnan(values) & ~np.isinf(values)] + + if len(values) == 0: + ax.set_title(f"{metric_name}\n(no valid data)") + continue + + ax.hist(values, bins=30, 
color=colors[idx % 10], alpha=0.7, edgecolor="black", density=True) + + thresh = thresholds.get(metric_name, {}) + has_thresh = False + if not np.isnan(thresh.get("min", np.nan)): + ax.axvline(thresh["min"], color="red", ls="--", lw=2, label=f"min={thresh['min']:.2g}") + has_thresh = True + if not np.isnan(thresh.get("max", np.nan)): + ax.axvline(thresh["max"], color="blue", ls="--", lw=2, label=f"max={thresh['max']:.2g}") + has_thresh = True + + ax.set_xlabel(metric_name) + ax.set_ylabel("Density") + if has_thresh: + ax.legend(fontsize=8, loc="upper right") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + for idx in range(len(metrics_to_plot), n_rows * n_cols): + axes[idx // n_cols, idx % n_cols].set_visible(False) + + plt.tight_layout() + self.figure = fig + self.axes = axes + + +class WaveformOverlayWidget(BaseWidget): + """Plot overlaid waveforms grouped by unit label type.""" + + def __init__( + self, + sorting_analyzer, + unit_type: np.ndarray, + unit_type_string: np.ndarray, + split_non_somatic: bool = False, + backend=None, + **backend_kwargs, + ): + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + plot_data = dict( + sorting_analyzer=sorting_analyzer, + unit_type=unit_type, + unit_type_string=unit_type_string, + split_non_somatic=split_non_somatic, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + + dp = to_attr(data_plot) + sorting_analyzer = dp.sorting_analyzer + unit_type = dp.unit_type + split_non_somatic = dp.split_non_somatic + + if not sorting_analyzer.has_extension("templates"): + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.text( + 0.5, + 0.5, + "Templates extension not computed.\nRun: analyzer.compute('templates')", + ha="center", + va="center", + fontsize=12, + ) + ax.axis("off") + self.figure = fig + self.axes = ax + return + + templates_ext = sorting_analyzer.get_extension("templates") + templates = templates_ext.get_templates(operator="average") + + if split_non_somatic: + labels = {0: "NOISE", 1: "GOOD", 2: "MUA", 3: "NON_SOMA_GOOD", 4: "NON_SOMA_MUA"} + n_plots, nrows, ncols = 5, 2, 3 + else: + labels = {0: "NOISE", 1: "GOOD", 2: "MUA", 3: "NON_SOMA"} + n_plots, nrows, ncols = 4, 2, 2 + + fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows)) + axes_flat = axes.flatten() + + for plot_idx in range(n_plots): + ax = axes_flat[plot_idx] + type_label = labels.get(plot_idx, "") + mask = unit_type == plot_idx + n_units = np.sum(mask) + + if n_units > 0: + unit_indices = np.where(mask)[0] + alpha = max(0.05, min(0.3, 10 / n_units)) + for unit_idx in unit_indices: + template = templates[unit_idx] + best_chan = np.argmax(np.max(np.abs(template), axis=0)) + ax.plot(template[:, best_chan], color="black", alpha=alpha, linewidth=0.5) + ax.set_title(f"{type_label} (n={n_units})") + else: + ax.set_title(f"{type_label} (n=0)") + ax.text(0.5, 0.5, "No units", ha="center", va="center", transform=ax.transAxes) + + for spine in ax.spines.values(): + spine.set_visible(False) + ax.set_xticks([]) + ax.set_yticks([]) + + for idx in range(n_plots, nrows * ncols): + axes_flat[idx].set_visible(False) + + plt.tight_layout() + self.figure = fig + self.axes = axes + + +class UpsetPlotWidget(BaseWidget): + """ + Plot UpSet plots showing which metrics fail together for each unit type. + + Requires `upsetplot` package. 
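+    If the package is missing, a placeholder figure with install instructions
+    is drawn instead of raising an ImportError.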
Each unit type shows relevant metrics: + NOISE -> waveform metrics, MUA -> spike quality metrics, NON_SOMA -> non-somatic metrics. + """ + + WAVEFORM_METRICS = [ + "num_positive_peaks", + "num_negative_peaks", + "peak_to_trough_duration", + "waveform_baseline_flatness", + "peak_after_to_trough_ratio", + "exp_decay", + ] + SPIKE_QUALITY_METRICS = [ + "amplitude_median", + "snr_bombcell", + "amplitude_cutoff", + "num_spikes", + "rp_contamination", + "presence_ratio", + "drift_ptp", + ] + NON_SOMATIC_METRICS = [ + "peak_before_to_trough_ratio", + "peak_before_width", + "trough_width", + "peak_before_to_peak_after_ratio", + "main_peak_to_trough_ratio", + ] + + def __init__( + self, + unit_type: np.ndarray, + unit_type_string: np.ndarray, + quality_metrics=None, + template_metrics=None, + thresholds: Optional[dict] = None, + unit_types_to_plot: Optional[list] = None, + split_non_somatic: bool = False, + min_subset_size: int = 1, + backend=None, + **backend_kwargs, + ): + from spikeinterface.curation import bombcell_get_default_thresholds + + combined_metrics = _combine_metrics(quality_metrics, template_metrics) + if combined_metrics is None: + raise ValueError("At least one of quality_metrics or template_metrics must be provided") + + if thresholds is None: + thresholds = bombcell_get_default_thresholds() + if unit_types_to_plot is None: + if split_non_somatic: + unit_types_to_plot = ["NOISE", "MUA", "NON_SOMA_GOOD", "NON_SOMA_MUA"] + else: + unit_types_to_plot = ["NOISE", "MUA", "NON_SOMA"] + + plot_data = dict( + quality_metrics=combined_metrics, + unit_type=unit_type, + unit_type_string=unit_type_string, + thresholds=thresholds, + unit_types_to_plot=unit_types_to_plot, + min_subset_size=min_subset_size, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def _get_metrics_for_unit_type(self, unit_type_label): + if unit_type_label == "NOISE": + return self.WAVEFORM_METRICS + elif unit_type_label == "MUA": + return self.SPIKE_QUALITY_METRICS + elif unit_type_label in ("NON_SOMA", "NON_SOMA_GOOD", "NON_SOMA_MUA"): + return self.NON_SOMATIC_METRICS + return None + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import warnings + import matplotlib.pyplot as plt + import pandas as pd + + dp = to_attr(data_plot) + quality_metrics = dp.quality_metrics + unit_type_string = dp.unit_type_string + thresholds = dp.thresholds + unit_types_to_plot = dp.unit_types_to_plot + min_subset_size = dp.min_subset_size + + try: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning, module="upsetplot") + from upsetplot import UpSet, from_memberships + except ImportError: + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.text( + 0.5, + 0.5, + "UpSet plots require 'upsetplot' package.\n\npip install upsetplot", + ha="center", + va="center", + fontsize=14, + family="monospace", + bbox=dict(boxstyle="round", facecolor="lightyellow", edgecolor="orange"), + ) + ax.axis("off") + ax.set_title("UpSet Plot - Package Not Installed", fontsize=16) + self.figure = fig + self.axes = ax + self.figures = [fig] + return + + failure_table = self._build_failure_table(quality_metrics, thresholds) + figures = [] + axes_list = [] + + for unit_type_label in unit_types_to_plot: + mask = unit_type_string == unit_type_label + n_units = np.sum(mask) + if n_units == 0: + continue + + relevant_metrics = self._get_metrics_for_unit_type(unit_type_label) + if relevant_metrics is not None: + available_metrics = [m for m in relevant_metrics if m in 
failure_table.columns] + if len(available_metrics) == 0: + continue + unit_failure_table = failure_table[available_metrics] + else: + unit_failure_table = failure_table + + unit_failures = unit_failure_table.loc[mask] + memberships = [] + for idx in unit_failures.index: + failed = unit_failures.columns[unit_failures.loc[idx]].tolist() + if failed: + memberships.append(failed) + + if not memberships: + continue + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning, module="upsetplot") + upset_data = from_memberships(memberships) + upset_data = upset_data[upset_data >= min_subset_size] + if len(upset_data) == 0: + continue + + fig = plt.figure(figsize=(12, 6)) + UpSet( + upset_data, + subset_size="count", + show_counts=True, + sort_by="cardinality", + sort_categories_by="cardinality", + ).plot(fig=fig) + fig.suptitle(f"{unit_type_label} (n={n_units})", fontsize=14, y=1.02) + figures.append(fig) + axes_list.append(fig.axes) + + if not figures: + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.text(0.5, 0.5, "No units found or no metric failures detected.", ha="center", va="center", fontsize=12) + ax.axis("off") + figures = [fig] + axes_list = [ax] + + self.figures = figures + self.figure = figures[0] if figures else None + self.axes = axes_list + + def _build_failure_table(self, quality_metrics, thresholds): + import pandas as pd + + absolute_value_metrics = ["amplitude_median"] + failure_data = {} + + for metric_name, thresh in thresholds.items(): + if metric_name not in quality_metrics.columns: + continue + values = quality_metrics[metric_name].values.copy() + if metric_name in absolute_value_metrics: + values = np.abs(values) + + failed = np.isnan(values) + if not np.isnan(thresh.get("min", np.nan)): + failed |= values < thresh["min"] + if not np.isnan(thresh.get("max", np.nan)): + failed |= values > thresh["max"] + failure_data[metric_name] = failed + + return pd.DataFrame(failure_data, index=quality_metrics.index) + + +# Convenience functions +def plot_labelling_histograms( + quality_metrics=None, template_metrics=None, thresholds=None, metrics_to_plot=None, backend=None, **kwargs +): + """Plot histograms of quality metrics with threshold lines.""" + return LabellingHistogramsWidget( + quality_metrics=quality_metrics, + template_metrics=template_metrics, + thresholds=thresholds, + metrics_to_plot=metrics_to_plot, + backend=backend, + **kwargs, + ) + + +def plot_waveform_overlay( + sorting_analyzer, unit_type, unit_type_string, split_non_somatic=False, backend=None, **kwargs +): + """Plot overlaid waveforms grouped by unit label type.""" + return WaveformOverlayWidget( + sorting_analyzer, unit_type, unit_type_string, split_non_somatic=split_non_somatic, backend=backend, **kwargs + ) + + +def plot_upset( + unit_type, + unit_type_string, + quality_metrics=None, + template_metrics=None, + thresholds=None, + unit_types_to_plot=None, + split_non_somatic=False, + min_subset_size=1, + backend=None, + **kwargs, +): + """Plot UpSet plots showing which metrics fail together for each unit type.""" + return UpsetPlotWidget( + unit_type, + unit_type_string, + quality_metrics=quality_metrics, + template_metrics=template_metrics, + thresholds=thresholds, + unit_types_to_plot=unit_types_to_plot, + split_non_somatic=split_non_somatic, + min_subset_size=min_subset_size, + backend=backend, + **kwargs, + ) + + +def plot_unit_labelling_all( + sorting_analyzer, + unit_type: np.ndarray, + unit_type_string: np.ndarray, + quality_metrics=None, + template_metrics=None, + 
thresholds: Optional[dict] = None, + split_non_somatic: bool = False, + include_upset: bool = True, + save_folder=None, + backend=None, + **kwargs, +): + """ + Generate all unit labelling plots and optionally save to folder. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object. + unit_type : np.ndarray + Array of unit type codes (0=NOISE, 1=GOOD, 2=MUA, 3=NON_SOMA, etc.). + unit_type_string : np.ndarray + Array of unit type labels as strings. + quality_metrics : pd.DataFrame, optional + Quality metrics DataFrame. If None, loads from sorting_analyzer. + template_metrics : pd.DataFrame, optional + Template metrics DataFrame. If None, loads from sorting_analyzer. + thresholds : dict, optional + Threshold dictionary. If None, uses default thresholds. + split_non_somatic : bool, default: False + Whether to split NON_SOMA into NON_SOMA_GOOD and NON_SOMA_MUA. + include_upset : bool, default: True + Whether to include UpSet plots (requires upsetplot package). + save_folder : str or Path, optional + If provided, saves all plots and CSV results to this folder. + backend : str, optional + Plotting backend. + **kwargs + Additional arguments passed to plot functions. + + Returns + ------- + dict + Dictionary with keys 'histograms', 'waveforms', 'upset' containing widget objects. + """ + from pathlib import Path + from spikeinterface.curation import bombcell_get_default_thresholds, save_labelling_results + + if thresholds is None: + thresholds = bombcell_get_default_thresholds() + + # Load metrics from sorting_analyzer if not provided + if quality_metrics is None and sorting_analyzer.has_extension("quality_metrics"): + quality_metrics = sorting_analyzer.get_extension("quality_metrics").get_data() + if template_metrics is None and sorting_analyzer.has_extension("template_metrics"): + template_metrics = sorting_analyzer.get_extension("template_metrics").get_data() + + combined_metrics = _combine_metrics(quality_metrics, template_metrics) + + results = {} + + # Histograms + if combined_metrics is not None: + results["histograms"] = plot_labelling_histograms( + quality_metrics=quality_metrics, + template_metrics=template_metrics, + thresholds=thresholds, + backend=backend, + **kwargs, + ) + + # Waveform overlay + results["waveforms"] = plot_waveform_overlay( + sorting_analyzer, unit_type, unit_type_string, split_non_somatic=split_non_somatic, backend=backend, **kwargs + ) + + # UpSet plots + if include_upset and combined_metrics is not None: + results["upset"] = plot_upset( + unit_type, + unit_type_string, + quality_metrics=quality_metrics, + template_metrics=template_metrics, + thresholds=thresholds, + split_non_somatic=split_non_somatic, + backend=backend, + **kwargs, + ) + + # Save to folder if requested + if save_folder is not None: + save_folder = Path(save_folder) + save_folder.mkdir(parents=True, exist_ok=True) + + # Save plots + if "histograms" in results and results["histograms"].figure is not None: + results["histograms"].figure.savefig(save_folder / "labelling_histograms.png", dpi=150, bbox_inches="tight") + if "waveforms" in results and results["waveforms"].figure is not None: + results["waveforms"].figure.savefig(save_folder / "waveform_overlay.png", dpi=150, bbox_inches="tight") + if "upset" in results and hasattr(results["upset"], "figures"): + for i, fig in enumerate(results["upset"].figures): + fig.savefig(save_folder / f"upset_plot_{i}.png", dpi=150, bbox_inches="tight") + + # Save CSV results + if combined_metrics is not None: + 
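+            # Persist the combined metrics table, unit labels, and thresholds as CSV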
save_labelling_results(combined_metrics, unit_type, unit_type_string, thresholds, save_folder) + + return results diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 6edba67c96..d8f2f46856 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -37,12 +37,22 @@ from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyAgreementMatrix, StudySummary from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget +from .unit_labelling import ( + LabellingHistogramsWidget, + WaveformOverlayWidget, + UpsetPlotWidget, + plot_labelling_histograms, + plot_waveform_overlay, + plot_upset, + plot_unit_labelling_all, +) widget_list = [ AgreementMatrixWidget, AllAmplitudesDistributionsWidget, AmplitudesWidget, AutoCorrelogramsWidget, + LabellingHistogramsWidget, ConfusionMatrixWidget, ComparisonCollisionBySimilarityWidget, CrossCorrelogramsWidget, @@ -75,6 +85,8 @@ UnitTemplatesWidget, UnitWaveformDensityMapWidget, UnitWaveformsWidget, + UpsetPlotWidget, + WaveformOverlayWidget, StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances,
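A minimal sketch (untested) of exercising the new `exp_decay` options introduced above. It assumes an existing `SortingAnalyzer` folder (the path is illustrative) and that `compute` accepts the per-metric `metric_params` mapping used by the `BaseMetricExtension` framework in this PR:

```python
import spikeinterface as si

# Illustrative path: any analyzer with the "templates" extension computed.
analyzer = si.load_sorting_analyzer("my_analyzer_folder")

# Bombcell-style spatial decay: restrict the fit to channels within 33 um in x,
# use the linear model, and normalize amplitudes before fitting. Leaving
# channel_tolerance=None keeps the previous all-channel exponential fit.
analyzer.compute(
    "template_metrics",
    metric_names=["exp_decay"],
    metric_params={
        "exp_decay": {
            "peak_function": "ptp",
            "min_r2": 0.2,
            "linear_fit": True,
            "channel_tolerance": 33,
            "normalize_decay": True,
        }
    },
)

exp_decay = analyzer.get_extension("template_metrics").get_data()["exp_decay"]
```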