diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index e7b9dee2c7..847e8648a0 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -1107,6 +1107,48 @@ def compute_amplitude_medians(sorting_analyzer, peak_sign="neg", unit_ids=None): _default_params["amplitude_median"] = dict(peak_sign="neg") +def compute_waveform_ptp_medians(sorting_analyzer, unit_ids=None): + """ + Compute median of the peak-to-peak (PTP) values of the waveforms. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + unit_ids : list or None + List of unit ids to compute the waveform PTP medians. If None, all units are used. + + Returns + ------- + all_waveform_ptp_medians : dict + Estimated waveform PTP median for each unit ID. + + References + ---------- + Inspired by bombcell folks + """ + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids + + _has_required_extensions(sorting_analyzer, metric_name="waveform_ptp_median") + + wfs_ext = sorting_analyzer.get_extension("waveforms") + extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index", mode="peak_to_peak") + all_waveform_ptp_medians = {} + + for unit_id in unit_ids: + waveforms = wfs_ext.get_waveforms_one_unit(unit_id, force_dense=True) + waveform_max_channel = waveforms[:, :, extremum_channel_indices[unit_id]] + ptps = np.ptp(waveform_max_channel, axis=1) + median_ptp = np.median(ptps) + all_waveform_ptp_medians[unit_id] = median_ptp + + return all_waveform_ptp_medians + + +_default_params["waveform_ptp_median"] = dict() + + def compute_drift_metrics( sorting_analyzer, interval_s=60, diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 5e769ab8eb..e9091d9c9d 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -11,6 +11,7 @@ "drift": ["spike_locations"], "sd_ratio": ["templates", "spike_amplitudes"], "noise_cutoff": ["spike_amplitudes"], + "waveform_ptp_median": ["templates", "waveforms"], } @@ -30,6 +31,7 @@ compute_amplitude_cv_metrics, compute_sd_ratio, compute_noise_cutoffs, + compute_waveform_ptp_medians, ) from .pca_metrics import ( @@ -64,6 +66,7 @@ "drift": compute_drift_metrics, "sd_ratio": compute_sd_ratio, "noise_cutoff": compute_noise_cutoffs, + "waveform_ptp_median": compute_waveform_ptp_medians, } @@ -96,6 +99,7 @@ "silhouette": ["silhouette"], "silhouette_full": ["silhouette_full"], "noise_cutoff": ["noise_cutoff", "noise_ratio"], + "waveform_ptp_median": ["waveform_ptp_median"], } # this dict allows us to ensure the appropriate dtype of metrics rather than allow Pandas to infer them @@ -133,4 +137,5 @@ "silhouette_full": float, "noise_cutoff": float, "noise_ratio": float, + "waveform_ptp_median": float, }