diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 7c0f7e3dae..a5db64261b 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -248,10 +248,11 @@ def _prepare_templates(self): else: sparsity = self.templates.sparsity.mask - units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) - self.units_overlaps = units_overlaps > 0 + # units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) self.unit_overlaps_indices = {} + self.units_overlaps = {} for i in range(self.num_templates): + self.units_overlaps[i] = np.sum(np.logical_and(sparsity[i, :], sparsity), axis=1) > 0 self.unit_overlaps_indices[i] = np.flatnonzero(self.units_overlaps[i]) templates_array = self.templates.get_dense_templates().copy() diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 5c15f3e9c3..bd5071f2b9 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -278,10 +278,17 @@ def from_templates(cls, params, templates): Dataclass object for aggregating channel sparsity variables together. """ visible_channels = templates.sparsity.mask - unit_overlap = np.sum( - np.logical_and(visible_channels[:, np.newaxis, :], visible_channels[np.newaxis, :, :]), axis=2 - ) - unit_overlap = unit_overlap > 0 + num_templates = templates.get_dense_templates().shape[0] + unit_overlap = np.zeros((num_templates, num_templates), dtype=bool) + + for i in range(num_templates): + unit_overlap[i] = np.sum(np.logical_and(visible_channels[i], visible_channels), axis=1) > 0 + + # unit_overlap = np.sum( + # np.logical_and(visible_channels[:, np.newaxis, :], visible_channels[np.newaxis, :, :]), axis=2 + # ) + # unit_overlap = unit_overlap > 0 + unit_overlap = np.repeat(unit_overlap, params.jitter_factor, axis=0) sparsity = cls(visible_channels=visible_channels, unit_overlap=unit_overlap) return sparsity diff --git a/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py b/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py index f8c0727c26..e68660ce83 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py @@ -216,7 +216,6 @@ def __init__( device=None, radius_um=50, return_tensor=False, - random_chunk_kwargs={}, return_output=True, ): if not HAVE_TORCH: @@ -288,14 +287,13 @@ def __init__( exclude_sweep_ms=0.1, radius_um=50, noise_levels=None, - random_chunk_kwargs={}, opencl_context_kwargs={}, ): if not HAVE_PYOPENCL: raise ModuleNotFoundError('"locally_exclusive_cl" needs pyopencl which is not installed') LocallyExclusivePeakDetector.__init__( - self, recording, peak_sign, detect_threshold, exclude_sweep_ms, radius_um, noise_levels, random_chunk_kwargs + self, recording, peak_sign, detect_threshold, exclude_sweep_ms, radius_um, noise_levels ) self.executor = OpenCLDetectPeakExecutor(