Add SLAy auto-merge preset #4190
Merged: samuelgarcia merged 14 commits into SpikeInterface:main from chrishalcrow:slay-interface on Dec 18, 2025
Commits (14), showing changes from all commits:
6441f50 add initial slay structure in compute_merge_unit_groups (chrishalcrow)
9ff3ef0 add all scores (chrishalcrow)
8ed4d81 add docs (chrishalcrow)
45656ce add template_similarity to slay_score requirements (chrishalcrow)
3994547 respond to alessio (chrishalcrow)
7571822 make default pair mask triu (chrishalcrow)
9c9dac5 Propagate precomputed pairwise similarity to slay (alejoe91)
4596915 Merge branch 'main' into slay-interface (alejoe91)
681c16b Merge branch 'main' into slay-interface (alejoe91)
f5fc2ed Merge branch 'main' into slay-interface (samuelgarcia)
0919fc0 Merge branch 'main' of github.com:SpikeInterface/spikeinterface into … (alejoe91)
3ec5d44 Add slay to auto_merge tests (alejoe91)
49dca68 Merge branch 'main' into slay-interface (chrishalcrow)
5efd2da add more docs to docstrings (chrishalcrow)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,6 +53,10 @@ | |
| "knn", | ||
| "quality_score", | ||
| ], | ||
| "slay": [ | ||
| "template_similarity", | ||
| "slay_score", | ||
| ], | ||
| } | ||
|
|
||
| _required_extensions = { | ||
|
|
@@ -61,6 +65,7 @@ | |
| "snr": ["templates", "noise_levels"], | ||
| "template_similarity": ["templates", "template_similarity"], | ||
| "knn": ["templates", "spike_locations", "spike_amplitudes"], | ||
| "slay_score": ["correlograms", "template_similarity"], | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -85,6 +90,7 @@ | |
| "censored_period_ms": 0.3, | ||
| }, | ||
| "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, | ||
| "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5}, | ||
| } | ||
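With the additions above in place, the new preset can be used like any other. A minimal usage sketch, assuming the `steps_params` keyword of the existing `compute_merge_unit_groups` API and a `SortingAnalyzer` named `sorting_analyzer` with the required extensions ("correlograms", "template_similarity") available or computable:

```python
from spikeinterface.curation import compute_merge_unit_groups

# preset="slay" runs the "template_similarity" and "slay_score" steps;
# the per-step parameters below simply restate the defaults shown above
merge_unit_groups = compute_merge_unit_groups(
    sorting_analyzer,
    preset="slay",
    steps_params={"slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5}},
)
```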
|
|
||
|
|
||
|
|
@@ -114,6 +120,8 @@ def compute_merge_unit_groups( | |
| * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`) | ||
| * "knn": the two units are close in the feature space | ||
| * "quality_score": the unit "quality score" is increased after the merge | ||
| * "slay_score": a combined score, factoring in a template similarity measure, a cross-correlation significance measure | ||
| and a sliding refractory period violation measure, based on the SLAy algorithm. | ||
|
|
||
| The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in | ||
| contamination (**C**), weighted by a factor **k** (`firing_contamination_balance`). | ||
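For reference, the "slay_score" rule above reduces to a single merge decision metric (symbols follow `compute_slay_matrix` further down: sigma is the template similarity, rho the cross-correlation significance, eta the sliding refractory period violation confidence):

```latex
M_{ij} = \sigma_{ij} + k_1\,\rho_{ij} - k_2\,\eta_{ij},
\qquad \text{pair } (i, j) \text{ is kept if } M_{ij} > \texttt{slay\_threshold}
```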
|
|
@@ -145,6 +153,9 @@ def compute_merge_unit_groups( | |
| * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. | ||
| | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | ||
| | "knn", "quality_score" | ||
| * | "slay": an approximate implementation of SLAy, original implementation at https://github.com/saikoukunt/SLAy. | ||
| | The spikeinterface version uses `template_similarity`, rather than an auto-encoder. | ||
| | It uses the following steps: "template_similarity", "slay_score" | ||
|
|
||
| If `preset` is None, you can specify the steps manually with the `steps` parameter. | ||
| resolve_graph : bool, default: True | ||
|
|
@@ -363,6 +374,14 @@ def compute_merge_unit_groups( | |
| ) | ||
| outs["pairs_decreased_score"] = pairs_decreased_score | ||
|
|
||
| elif step == "slay_score": | ||
|
|
||
| M_ij = compute_slay_matrix( | ||
| sorting_analyzer, params["k1"], params["k2"], templates_diff=outs["templates_diff"], pair_mask=pair_mask | ||
| ) | ||
|
|
||
| pair_mask = pair_mask & (M_ij > params["slay_threshold"]) | ||
|
|
||
| # FINAL STEP : create the final list from pair_mask boolean matrix | ||
| ind1, ind2 = np.nonzero(pair_mask) | ||
| merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2])) | ||
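To make the final step concrete, here is a small sketch (unit ids and scores are invented for illustration, not from the PR) of how the thresholded pair mask becomes a list of merge groups:

```python
import numpy as np

unit_ids = np.array(["u0", "u1", "u2"])
# upper-triangle candidate mask coming out of the earlier steps
pair_mask = np.triu(np.ones((3, 3), dtype=bool), 1)
# invented SLAy decision metric values for the three pairs
M_ij = np.array([[0.0, 0.7, 0.2],
                 [0.0, 0.0, 0.6],
                 [0.0, 0.0, 0.0]])

pair_mask = pair_mask & (M_ij > 0.5)   # keep pairs above slay_threshold
ind1, ind2 = np.nonzero(pair_mask)
merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2]))
# -> pairs (u0, u1) and (u1, u2), later resolved into groups when resolve_graph=True
```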
|
|
@@ -550,6 +569,7 @@ def get_potential_auto_merge( | |
| * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`) | ||
| * "knn": the two units are close in the feature space | ||
| * "quality_score": the unit "quality score" is increased after the merge | ||
| * "slay_score": a combined score, factoring in a template similarity measure, a cross-correlation significance measure and a sliding refractory period violation measure, based on the SLAy algorithm. | ||
|
|
||
| The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in | ||
| contamination (**C**), weighted by a factor **k** (`firing_contamination_balance`). | ||
|
|
@@ -566,7 +586,7 @@ def get_potential_auto_merge( | |
| ---------- | ||
| sorting_analyzer : SortingAnalyzer | ||
| The SortingAnalyzer | ||
| preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | None, default: "similarity_correlograms" | ||
| preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | "slay" | None, default: "similarity_correlograms" | ||
| The preset to use for the auto-merge. Presets combine different steps into a recipe and focus on: | ||
|
|
||
| * | "similarity_correlograms": mainly focused on template similarity and correlograms. | ||
|
|
@@ -581,6 +601,9 @@ def get_potential_auto_merge( | |
| * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. | ||
| | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | ||
| | "knn", "quality_score" | ||
| * | "slay": an approximate implementation of SLAy, original implementation at https://github.com/saikoukunt/SLAy. | ||
| | The spikeinterface version uses `template_similarity`, rather than an auto-encoder. | ||
| | It uses the following steps: "template_similarity", "slay_score" | ||
|
|
||
| If `preset` is None, you can specify the steps manually with the `steps` parameter. | ||
| resolve_graph : bool, default: False | ||
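The legacy entry point documented here should accept the same preset; a hedged one-liner based on the docstring above:

```python
from spikeinterface.curation import get_potential_auto_merge

# with resolve_graph=False (the default shown above) this returns candidate unit pairs
potential_merges = get_potential_auto_merge(sorting_analyzer, preset="slay")
```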
|
|
@@ -1525,3 +1548,251 @@ def estimate_cross_contamination( | |
| ) | ||
|
|
||
| return estimation, p_value | ||
|
|
||
|
|
||
| def compute_slay_matrix( | ||
| sorting_analyzer: SortingAnalyzer, | ||
| k1: float, | ||
| k2: float, | ||
| templates_diff: np.ndarray | None, | ||
| pair_mask: np.ndarray | None = None, | ||
| ): | ||
| """ | ||
| Computes the "merge decision metric" from the SLAy method, made from combining | ||
| a template similarity measure, a cross-correlation significance measure and a | ||
| sliding refractory period violation measure. A large M suggests that two | ||
| units should be merged. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| sorting_analyzer : SortingAnalyzer | ||
| The sorting analyzer object containing the spike sorting data | ||
| k1 : float | ||
| Coefficient determining the importance of the cross-correlation significance | ||
| k2 : float | ||
| Coefficient determining the importance of the sliding rp violation | ||
| templates_diff : np.ndarray | None | ||
| Pre-computed template similarity difference matrix. If None, it will be retrieved from the sorting_analyzer. | ||
| pair_mask : None | np.ndarray, default: None | ||
| A bool matrix describing which pairs are possible merges based on previous steps | ||
|
|
||
|
|
||
| References | ||
| ---------- | ||
| Based on computation originally implemented in SLAy [Koukuntla]_. | ||
|
|
||
| Implementation is based on one of the original implementations written by Sai Koukuntla, | ||
| found at https://github.com/saikoukunt/SLAy. | ||
| """ | ||
|
|
||
| num_units = sorting_analyzer.get_num_units() | ||
|
|
||
| if pair_mask is None: | ||
| pair_mask = np.triu(np.arange(num_units), 1) > 0 | ||
|
|
||
| if templates_diff is not None: | ||
| sigma_ij = 1 - templates_diff | ||
| else: | ||
| sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data() | ||
| rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask) | ||
|
|
||
| M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij | ||
|
|
||
| return M_ij | ||
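A worked example with the default coefficients (k1=0.25, k2=1, slay_threshold=0.5); the pair scores are invented for illustration:

```python
sigma = 0.80   # template similarity for the pair
rho = 0.40     # cross-correlation significance
eta = 0.10     # sliding refractory period violation confidence

k1, k2, slay_threshold = 0.25, 1, 0.5
M = sigma + k1 * rho - k2 * eta   # 0.80 + 0.10 - 0.10 = 0.80
print(M > slay_threshold)         # True -> the pair stays a merge candidate
```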
|
|
||
|
|
||
| def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray): | ||
| """ | ||
| Computes a cross-correlation significance measure and a sliding refractory period violation | ||
| measure for all units in the `sorting_analyzer`. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| sorting_analyzer : SortingAnalyzer | ||
| The sorting analyzer object containing the spike sorting data | ||
| pair_mask : np.ndarray | ||
| A bool matrix describing which pairs are possible merges based on previous steps | ||
| """ | ||
|
|
||
| correlograms_extension = sorting_analyzer.get_extension("correlograms") | ||
| ccgs, _ = correlograms_extension.get_data() | ||
|
|
||
| # convert to seconds for SLAy functions | ||
| bin_size_ms = correlograms_extension.params["bin_ms"] | ||
|
|
||
| rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)]) | ||
| eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)]) | ||
|
|
||
| for unit_index_1, _ in enumerate(sorting_analyzer.unit_ids): | ||
| for unit_index_2, _ in enumerate(sorting_analyzer.unit_ids): | ||
|
|
||
| # Don't waste time computing the other metrics if the units are not candidate merges | ||
| if not pair_mask[unit_index_1, unit_index_2]: | ||
| continue | ||
|
|
||
| xgram = ccgs[unit_index_1, unit_index_2, :] | ||
|
|
||
| rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair( | ||
| xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0 | ||
| ) | ||
| eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms) | ||
|
|
||
| return rho_ij, eta_ij | ||
|
|
||
|
|
||
| def _compute_xcorr_pair( | ||
| xgram, | ||
| bin_size_s: float, | ||
| min_xcorr_rate: float, | ||
| ) -> float: | ||
| """ | ||
| Calculates a cross-correlation significance metric for a cluster pair. | ||
|
|
||
| Uses the Wasserstein distance between an observed cross-correlogram and a null | ||
| distribution as an estimate of how significant the dependence between | ||
| two neurons is. Low spike count cross-correlograms have large Wasserstein | ||
| distances from the null by chance, so we first try to expand the window size. If | ||
| that fails to yield enough spikes, we apply a penalty to the metric. | ||
|
|
||
| Ported from https://github.com/saikoukunt/SLAy. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| xgram : np.array | ||
| The raw cross-correlogram for the cluster pair. | ||
| bin_size_s : float | ||
| The width in seconds of the bin size of the input ccgs. | ||
| min_xcorr_rate : float | ||
| The minimum ccg firing rate in Hz. | ||
|
|
||
| Returns | ||
| ------- | ||
| sig : float | ||
| The calculated cross-correlation significance metric. | ||
| """ | ||
|
|
||
| from scipy.signal import butter, find_peaks_cwt, sosfiltfilt | ||
| from scipy.stats import wasserstein_distance | ||
|
|
||
| # calculate low-pass filtered second derivative of ccg | ||
| fs = 1 / bin_size_s | ||
| cutoff_freq = 100 | ||
| nyqist = fs / 2 | ||
| cutoff = cutoff_freq / nyqist | ||
| peak_width = 0.002 / bin_size_s | ||
|
|
||
| xgram_2d = np.diff(xgram, 2) | ||
| sos = butter(4, cutoff, output="sos") | ||
| xgram_2d = sosfiltfilt(sos, xgram_2d) | ||
|
|
||
| if xgram.sum() == 0: | ||
| return 0 | ||
|
|
||
| # find negative peaks of second derivative of ccg, these are the edges of dips in ccg | ||
| peaks = find_peaks_cwt(-xgram_2d, peak_width, noise_perc=90) + 1 | ||
| # if no peaks are found, return a very low significance | ||
| if peaks.shape[0] == 0: | ||
| return -4 | ||
| peaks = np.abs(peaks - xgram.shape[0] / 2) | ||
| peaks = peaks[peaks > 0.5 * peak_width] | ||
| min_peaks = np.sort(peaks) | ||
|
|
||
| # start with peaks closest to 0 and move to the next set of peaks if the event count is too low | ||
| window_width = min_peaks * 1.5 | ||
| starts = np.maximum(xgram.shape[0] / 2 - window_width, 0) | ||
| ends = np.minimum(xgram.shape[0] / 2 + window_width, xgram.shape[0] - 1) | ||
| ind = 0 | ||
| xgram_window = xgram[int(starts[0]) : int(ends[0] + 1)] | ||
| xgram_sum = xgram_window.sum() | ||
| window_size = xgram_window.shape[0] * bin_size_s | ||
| while (xgram_sum < (min_xcorr_rate * window_size * 10)) and (ind < starts.shape[0]): | ||
| xgram_window = xgram[int(starts[ind]) : int(ends[ind] + 1)] | ||
| xgram_sum = xgram_window.sum() | ||
| window_size = xgram_window.shape[0] * bin_size_s | ||
| ind += 1 | ||
| # use the whole ccg if peak finding fails | ||
| if ind == starts.shape[0]: | ||
| xgram_window = xgram | ||
|
|
||
| # TODO: was getting error messages when xgram_window was all zero. Why was this happening? | ||
| if np.abs(xgram_window).sum() == 0: | ||
| return 0 | ||
|
|
||
| sig = ( | ||
| wasserstein_distance( | ||
| np.arange(xgram_window.shape[0]) / xgram_window.shape[0], | ||
| np.arange(xgram_window.shape[0]) / xgram_window.shape[0], | ||
| xgram_window, | ||
| np.ones_like(xgram_window), | ||
| ) | ||
| * 4 | ||
| ) | ||
|
|
||
| if xgram_window.sum() < (min_xcorr_rate * window_size): | ||
| sig *= (xgram_window.sum() / (min_xcorr_rate * window_size)) ** 2 | ||
|
|
||
| # if sig < 0.04 and xgram_window.sum() < (min_xcorr_rate * window_size): | ||
| if xgram_window.sum() < (min_xcorr_rate / 4 * window_size): | ||
| sig = -4 # don't merge if the event count is way too low | ||
|
|
||
| return sig | ||
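The core of the metric above is the Wasserstein distance between the windowed cross-correlogram and a flat null; a stripped-down sketch of just that comparison (the toy ccg and the omission of the windowing and penalty logic are simplifications, not part of the PR):

```python
import numpy as np
from scipy.stats import wasserstein_distance

# toy cross-correlogram with a dip around zero lag, for illustration only
xgram = np.ones(100)
xgram[45:55] = 0.1

support = np.arange(xgram.shape[0]) / xgram.shape[0]
# distance between the ccg-weighted distribution and a uniform null,
# scaled by 4 as in _compute_xcorr_pair
sig = 4 * wasserstein_distance(support, support, xgram, np.ones_like(xgram))
print(sig)  # larger values indicate a more structured, i.e. more significant, ccg
```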
|
|
||
|
|
||
| def _sliding_RP_viol_pair( | ||
| correlogram, | ||
| bin_size_ms: float, | ||
| accept_threshold: float = 0.15, | ||
| ) -> float: | ||
| """ | ||
| Calculate the sliding refractory period violation confidence for a cluster. | ||
|
|
||
| Ported from https://github.com/saikoukunt/SLAy. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| correlogram : np.array | ||
| The auto-correlogram of the cluster. | ||
| bin_size_ms : float | ||
| The width in ms of the bin size of the input ccgs. | ||
| accept_threshold : float, default: 0.15 | ||
| The maximum acceptable contamination, as a fraction of the baseline event rate. | ||
|
|
||
| Returns | ||
| ------- | ||
| sig : float | ||
| The refractory period violation confidence for the cluster. | ||
| """ | ||
| from scipy.signal import butter, sosfiltfilt | ||
| from scipy.stats import poisson | ||
|
|
||
| # create various refractory period sizes to test (between 0 and 20x bin size) | ||
| all_refractory_periods = np.arange(0, 21 * bin_size_ms, bin_size_ms) / 1000 | ||
| test_refractory_period_indices = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8") | ||
| test_refractory_periods = [ | ||
| all_refractory_periods[test_rp_index] for test_rp_index in test_refractory_period_indices | ||
| ] | ||
|
|
||
| # calculate and avg halves of acg to ensure symmetry | ||
| # keep only second half of acg, refractory period violations are compared from the center of acg | ||
| half_len = int(correlogram.shape[0] / 2) | ||
| correlogram = (correlogram[half_len:] + correlogram[:half_len][::-1]) / 2 | ||
|
|
||
| acg_cumsum = np.cumsum(correlogram) | ||
| sum_res = acg_cumsum[test_refractory_period_indices - 1] # -1 bc 0th bin corresponds to 0-bin_size ms | ||
|
|
||
| # low-pass filter acg and use max as baseline event rate | ||
| order = 4 # filter order | ||
| cutoff_freq = 250 # Hz | ||
| fs = 1 / bin_size_ms * 1000 | ||
| nyqist = fs / 2 | ||
| cutoff = cutoff_freq / nyqist | ||
| sos = butter(order, cutoff, btype="low", output="sos") | ||
| smoothed_acg = sosfiltfilt(sos, correlogram) | ||
|
|
||
| bin_rate_max = np.max(smoothed_acg) | ||
| max_conts_max = np.array(test_refractory_periods) / bin_size_ms * 1000 * (bin_rate_max * accept_threshold) | ||
| # compute confidence of less than acceptThresh contamination at each refractory period | ||
| confs = 1 - poisson.cdf(sum_res, max_conts_max) | ||
| rp_viol = 1 - confs.max() | ||
|
||
|
|
||
| return rp_viol | ||
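Finally, the essence of the sliding refractory period test: for each candidate refractory period, the observed event count near zero lag is compared, via a Poisson test, against the count expected if contamination sat exactly at `accept_threshold` of the baseline rate. A minimal sketch with invented numbers:

```python
import numpy as np
from scipy.stats import poisson

accept_threshold = 0.15
bin_rate_max = 40.0                    # baseline events per bin from the smoothed acg (invented)
n_bins = np.array([1, 2, 4, 8])        # bins covered by each test refractory period
observed = np.array([1, 2, 5, 9])      # cumulative events within each test rp (invented)

# expected events if contamination were exactly accept_threshold of baseline
expected = n_bins * bin_rate_max * accept_threshold
# confidence of less-than-accept_threshold contamination at each refractory period
confs = 1 - poisson.cdf(observed, expected)
rp_viol = 1 - confs.max()              # low values mean at least one rp looks clean
print(rp_viol)
```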