diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index cf0c72952b..aed01b6a2c 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -208,7 +208,9 @@ def _get_data(self): compute_template_similarity = ComputeTemplateSimilarity.function_factory() -def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method): +def _compute_similarity_matrix_numpy( + templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union" +): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] @@ -232,15 +234,16 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in range(num_templates): src_template = src_sliced_templates[i] - overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + local_mask = get_overlapping_mask_for_one_template(i, sparsity_mask, other_sparsity_mask, support=support) + overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount, j in enumerate(overlapping_templates): # symmetric values are handled later if same_array and j < i: # no need exhaustive looping when same template continue - src = src_template[:, mask[i, j]].reshape(1, -1) - tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1) + src = src_template[:, local_mask[j]].reshape(1, -1) + tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1) if method == "l1": norm_i = np.sum(np.abs(src)) @@ -273,9 +276,12 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num import numba @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) - def 
_compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method): + def _compute_similarity_matrix_numba( + templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union" + ): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] + num_channels = templates_array.shape[2] other_num_templates = other_templates_array.shape[0] num_shifts_both_sides = 2 * num_shifts + 1 @@ -284,7 +290,6 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t # So the matrix can be computed only for negative lags and be transposed - if same_array: # optimisation when array are the same because of symetry in shift shift_loop = list(range(-num_shifts, 1)) @@ -304,7 +309,23 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in numba.prange(num_templates): src_template = src_sliced_templates[i] - overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + + ## Ideally we would like to use this but numba does not support well function with numpy and boolean arrays + ## So we inline the function here + # local_mask = get_overlapping_mask_for_one_template(i, sparsity, other_sparsity, support=support) + + if support == "intersection": + local_mask = np.logical_and( + sparsity_mask[i, :], other_sparsity_mask + ) # shape (other_num_templates, num_channels) + elif support == "union": + local_mask = np.logical_or( + sparsity_mask[i, :], other_sparsity_mask + ) # shape (other_num_templates, num_channels) + elif support == "dense": + local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_) + + overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount in 
range(len(overlapping_templates)): @@ -313,8 +334,8 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num if same_array and j < i: # no need exhaustive looping when same template continue - src = src_template[:, mask[i, j]].flatten() - tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten() + src = src_template[:, local_mask[j]].flatten() + tgt = (tgt_templates[gcount][:, local_mask[j]]).flatten() norm_i = 0 norm_j = 0 @@ -360,6 +381,17 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num _compute_similarity_matrix = _compute_similarity_matrix_numpy +def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsity, support="union") -> np.ndarray: + + if support == "intersection": + mask = np.logical_and(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels) + elif support == "union": + mask = np.logical_or(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels) + elif support == "dense": + mask = np.ones(other_sparsity.shape, dtype=bool) + return mask + + def compute_similarity_with_templates_array( templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None ): @@ -369,6 +401,8 @@ def compute_similarity_with_templates_array( all_metrics = ["cosine", "l1", "l2"] + assert support in ["dense", "union", "intersection"], "support should be either dense, union or intersection" + if method not in all_metrics: raise ValueError(f"compute_template_similarity (method {method}) not exists") @@ -378,29 +412,25 @@ def compute_similarity_with_templates_array( assert ( templates_array.shape[2] == other_templates_array.shape[2] ), "The number of channels in the templates should be the same for both arrays" - num_templates = templates_array.shape[0] + # num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] - num_channels = templates_array.shape[2] - 
other_num_templates = other_templates_array.shape[0] - - mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) + # num_channels = templates_array.shape[2] + # other_num_templates = other_templates_array.shape[0] - if sparsity is not None and other_sparsity is not None: - - # make the input more flexible with either The object or the array mask + if sparsity is not None: sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity - other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity + else: + sparsity_mask = np.ones((templates_array.shape[0], templates_array.shape[2]), dtype=bool) - if support == "intersection": - mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) - elif support == "union": - mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) - units_overlaps = np.sum(mask, axis=2) > 0 - mask = np.logical_or(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) - mask[~units_overlaps] = False + if other_sparsity is not None: + other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity + else: + other_sparsity_mask = np.ones((other_templates_array.shape[0], other_templates_array.shape[2]), dtype=bool) assert num_shifts < num_samples, "max_lag is too large" - distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method) + distances = _compute_similarity_matrix( + templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support + ) distances = np.min(distances, axis=0) similarity = 1 - distances diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 9a25af444c..fa7d19fcbc 100644 --- 
a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -107,10 +107,23 @@ def test_equal_results_numba(params): rng = np.random.default_rng(seed=2205) templates_array = rng.random(size=(4, 20, 5), dtype=np.float32) other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32) - mask = np.ones((4, 2, 5), dtype=bool) - - result_numpy = _compute_similarity_matrix_numba(templates_array, other_templates_array, mask=mask, **params) - result_numba = _compute_similarity_matrix_numpy(templates_array, other_templates_array, mask=mask, **params) + sparsity_mask = np.ones((4, 5), dtype=bool) + other_sparsity_mask = np.ones((2, 5), dtype=bool) + + result_numba = _compute_similarity_matrix_numba( + templates_array, + other_templates_array, + sparsity_mask=sparsity_mask, + other_sparsity_mask=other_sparsity_mask, + **params, + ) + result_numpy = _compute_similarity_matrix_numpy( + templates_array, + other_templates_array, + sparsity_mask=sparsity_mask, + other_sparsity_mask=other_sparsity_mask, + **params, + ) assert np.allclose(result_numpy, result_numba, 1e-3) diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py index 23ec9d8e4c..d64d0cae3b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py @@ -541,7 +541,6 @@ def merge_peak_labels_from_templates( assert len(unit_ids) == templates_array.shape[0] from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array - from scipy.sparse.csgraph import connected_components similarity = compute_similarity_with_templates_array( templates_array,