Skip to content

Commit 39c8e60

Browse files
committed
Merge branch 'clustering_merge_isocut' of github.com:samuelgarcia/spikeinterface into clustering_merge_isocut
2 parents fea45bb + f2f2481 commit 39c8e60

File tree

3 files changed

+31
-67
lines changed

3 files changed

+31
-67
lines changed

src/spikeinterface/sortingcomponents/clustering/merge.py

Lines changed: 14 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def merge_peak_labels_from_features(
3333
template_sparse_mask,
3434
recording,
3535
features_dict_or_folder,
36-
radius_um=70.,
36+
radius_um=70.0,
3737
method="project_distribution",
3838
method_kwargs={},
3939
**job_kwargs,
@@ -57,7 +57,6 @@ def merge_peak_labels_from_features(
5757
template_sparse_mask,
5858
recording,
5959
features_dict_or_folder,
60-
6160
# sparse_wfs,
6261
# sparse_mask,
6362
radius_um=radius_um,
@@ -66,29 +65,23 @@ def merge_peak_labels_from_features(
6665
**job_kwargs,
6766
)
6867

69-
clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids = \
68+
clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids = (
7069
_apply_pair_mask_on_labels_and_recompute_templates(
71-
pair_mask,
72-
peak_labels,
73-
unit_ids,
74-
templates_array,
75-
template_sparse_mask
70+
pair_mask, peak_labels, unit_ids, templates_array, template_sparse_mask
7671
)
72+
)
7773

7874
return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids
7975

8076

81-
82-
8377
def find_merge_pairs_from_features(
8478
peaks,
8579
peak_labels,
8680
unit_ids,
8781
templates_array,
8882
template_sparse_mask,
8983
recording,
90-
features_dict_or_folder,
91-
84+
features_dict_or_folder,
9285
# sparse_wfs,
9386
# sparse_mask,
9487
radius_um=70,
@@ -120,7 +113,6 @@ def find_merge_pairs_from_features(
120113

121114
# compute template (no shift at this step)
122115

123-
124116
# templates = compute_template_from_sparse(
125117
# peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels, peak_shifts=None
126118
# )
@@ -131,13 +123,11 @@ def find_merge_pairs_from_features(
131123
# ms_after = features['ms_after']
132124
# svd_model = features['svd_model']
133125

134-
135126
# templates, final_sparsity_mask = get_templates_from_peaks_and_svd(
136-
# recording, peaks, peak_labels, ms_before, ms_after, svd_model, peaks_svd, sparse_mask, operator="average",
127+
# recording, peaks, peak_labels, ms_before, ms_after, svd_model, peaks_svd, sparse_mask, operator="average",
137128
# )
138129
# dense_templates_array = templates.templates_array
139130

140-
141131
labels_set = unit_ids.tolist()
142132

143133
max_chans = np.argmax(np.max(np.abs(templates_array), axis=1), axis=1)
@@ -319,8 +309,6 @@ def merge(
319309
final_shift = 0
320310
return is_merge, label0, label1, final_shift, merge_value
321311

322-
323-
324312
inds = np.concatenate([inds0, inds1])
325313
labels = np.zeros(inds.size, dtype="int")
326314
labels[inds0.size :] = 1
@@ -332,7 +320,6 @@ def merge(
332320
wfs0 = wfs[:cut, :, :]
333321
wfs1 = wfs[cut:, :, :]
334322

335-
336323
# num_samples = template0.shape[0]
337324

338325
# template0 = template0_[num_shift : num_samples - num_shift, :]
@@ -368,7 +355,6 @@ def merge(
368355
# wfs1 = wfs1_[:, best_shift : best_shift + template0.shape[0], :]
369356
# template1 = template1_[best_shift : best_shift + template0.shape[0], :]
370357

371-
372358
feat0 = wfs0.reshape(wfs0.shape[0], -1)
373359
feat1 = wfs1.reshape(wfs1.shape[0], -1)
374360
feat = np.concatenate([feat0, feat1], axis=0)
@@ -377,11 +363,11 @@ def merge(
377363

378364
if use_svd:
379365
from sklearn.decomposition import TruncatedSVD
366+
380367
n_pca_features = 3
381368
tsvd = TruncatedSVD(n_pca_features, random_state=seed)
382369
feat = tsvd.fit_transform(feat)
383370

384-
385371
if isinstance(n_pca_features, float):
386372
assert 0 < n_pca_features < 1, "n_components should be in ]0, 1["
387373
nb_dimensions = min(feat.shape[0], feat.shape[1])
@@ -416,7 +402,6 @@ def merge(
416402
# else:
417403
# feat = feat
418404

419-
420405
feat0 = feat[:cut]
421406
feat1 = feat[cut:]
422407

@@ -447,7 +432,6 @@ def merge(
447432
feat0 = feat[:cut]
448433
feat1 = feat[cut:]
449434

450-
451435
if criteria == "isocut":
452436
dipscore, cutpoint = isocut(feat)
453437
is_merge = dipscore < isocut_threshold
@@ -484,7 +468,7 @@ def merge(
484468
final_shift = 0
485469

486470
if DEBUG:
487-
# if dipscore < 4:
471+
# if dipscore < 4:
488472
import matplotlib.pyplot as plt
489473

490474
flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1)
@@ -570,33 +554,25 @@ def merge_peak_labels_from_templates(
570554
)
571555
pair_mask = similarity > similarity_thresh
572556

573-
574-
clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids = \
557+
clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids = (
575558
_apply_pair_mask_on_labels_and_recompute_templates(
576-
pair_mask,
577-
peak_labels,
578-
unit_ids,
579-
templates_array,
580-
template_sparse_mask
559+
pair_mask, peak_labels, unit_ids, templates_array, template_sparse_mask
581560
)
561+
)
582562

583563
return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids
584564

565+
585566
def _apply_pair_mask_on_labels_and_recompute_templates(
586-
pair_mask,
587-
peak_labels,
588-
unit_ids,
589-
templates_array,
590-
template_sparse_mask
567+
pair_mask, peak_labels, unit_ids, templates_array, template_sparse_mask
591568
):
592569
"""
593570
Resolve pairs graph.
594571
Apply to new labels.
595572
Recompute templates.
596573
"""
597-
598-
from scipy.sparse.csgraph import connected_components
599574

575+
from scipy.sparse.csgraph import connected_components
600576

601577
keep_template = np.ones(templates_array.shape[0], dtype="bool")
602578
clean_labels = peak_labels.copy()
@@ -638,4 +614,3 @@ def _apply_pair_mask_on_labels_and_recompute_templates(
638614
merge_sparsity_mask = merge_sparsity_mask[keep_template, :]
639615

640616
return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids
641-

src/spikeinterface/sortingcomponents/clustering/tdc.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
from spikeinterface.sortingcomponents.clustering.split import split_clusters
2323

2424
# from spikeinterface.sortingcomponents.clustering.merge import merge_clusters
25-
from spikeinterface.sortingcomponents.clustering.merge import merge_peak_labels_from_templates, merge_peak_labels_from_features
25+
from spikeinterface.sortingcomponents.clustering.merge import (
26+
merge_peak_labels_from_templates,
27+
merge_peak_labels_from_features,
28+
)
2629
from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd
2730
from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd
2831

@@ -55,7 +58,6 @@ class TdcClustering:
5558
# "clusterer": "isosplit6",
5659
# "clusterer_kwargs": {},
5760
"clusterer": "isosplit",
58-
5961
"clusterer_kwargs": {
6062
"n_init": 50,
6163
"min_cluster_size": 10,
@@ -71,11 +73,7 @@ class TdcClustering:
7173
"min_size_split": 10,
7274
},
7375
"do_merge_with_features": False,
74-
"merge_features_kwargs": {
75-
"merge_radius_um":50.,
76-
"criteria": "isocut",
77-
"isocut_threshold": 2.0
78-
},
76+
"merge_features_kwargs": {"merge_radius_um": 50.0, "criteria": "isocut", "isocut_threshold": 2.0},
7977
"do_merge_with_templates": True,
8078
"merge_template_kwargs": {
8179
"similarity_metric": "l1",
@@ -123,7 +121,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
123121

124122
original_labels = peaks["channel_index"]
125123

126-
127124
clusterer = params["split"]["clusterer"]
128125
clusterer_kwargs = params["split"]["clusterer_kwargs"]
129126

@@ -186,29 +183,24 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
186183
radius_um=merge_radius_um,
187184
method="project_distribution",
188185
method_kwargs=dict(
189-
feature_name="peaks_svd",
190-
waveforms_sparse_mask=sparse_mask,
191-
**merge_features_kwargs
192-
),
193-
**job_kwargs
186+
feature_name="peaks_svd", waveforms_sparse_mask=sparse_mask, **merge_features_kwargs
187+
),
188+
**job_kwargs,
194189
)
195190
else:
196191
post_merge_label1 = post_split_label.copy()
197-
192+
198193
if params["do_merge_with_templates"]:
199-
post_merge_label2, templates_array, template_sparse_mask, unit_ids = (
200-
merge_peak_labels_from_templates(
201-
peaks,
202-
post_merge_label1,
203-
unit_ids,
204-
templates_array,
205-
template_sparse_mask,
206-
**params["merge_template_kwargs"],
207-
)
194+
post_merge_label2, templates_array, template_sparse_mask, unit_ids = merge_peak_labels_from_templates(
195+
peaks,
196+
post_merge_label1,
197+
unit_ids,
198+
templates_array,
199+
template_sparse_mask,
200+
**params["merge_template_kwargs"],
208201
)
209202
else:
210203
post_merge_label2 = post_merge_label1.copy()
211-
212204

213205
dense_templates = Templates(
214206
templates_array=templates_array,
@@ -221,7 +213,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
221213
is_in_uV=False,
222214
)
223215

224-
225216
sparsity = ChannelSparsity(template_sparse_mask, unit_ids, recording.channel_ids)
226217
templates = dense_templates.to_sparse(sparsity)
227218

@@ -243,9 +234,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
243234
templates = templates.select_units(labels_set)
244235
labels_set = templates.unit_ids
245236

246-
247237
more_outs = dict(
248238
templates=templates,
249239
)
250240
return labels_set, final_peak_labels, more_outs
251-

src/spikeinterface/sortingcomponents/clustering/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_target
127127
# templates: numpy.array
128128
# Templates shape : (len(labels_set), num_samples, total_channels)
129129
# """
130-
130+
131131
# # NOTE SAM I think this is wrong, we should remove
132132

133133
# n = len(labels_set)

0 commit comments

Comments
 (0)