@@ -33,7 +33,7 @@ def merge_peak_labels_from_features(
3333 template_sparse_mask ,
3434 recording ,
3535 features_dict_or_folder ,
36- radius_um = 70. ,
36+ radius_um = 70.0 ,
3737 method = "project_distribution" ,
3838 method_kwargs = {},
3939 ** job_kwargs ,
@@ -57,7 +57,6 @@ def merge_peak_labels_from_features(
5757 template_sparse_mask ,
5858 recording ,
5959 features_dict_or_folder ,
60-
6160 # sparse_wfs,
6261 # sparse_mask,
6362 radius_um = radius_um ,
@@ -66,29 +65,23 @@ def merge_peak_labels_from_features(
6665 ** job_kwargs ,
6766 )
6867
69- clean_labels , merge_template_array , merge_sparsity_mask , new_unit_ids = \
68+ clean_labels , merge_template_array , merge_sparsity_mask , new_unit_ids = (
7069 _apply_pair_mask_on_labels_and_recompute_templates (
71- pair_mask ,
72- peak_labels ,
73- unit_ids ,
74- templates_array ,
75- template_sparse_mask
70+ pair_mask , peak_labels , unit_ids , templates_array , template_sparse_mask
7671 )
72+ )
7773
7874 return clean_labels , merge_template_array , merge_sparsity_mask , new_unit_ids
7975
8076
81-
82-
8377def find_merge_pairs_from_features (
8478 peaks ,
8579 peak_labels ,
8680 unit_ids ,
8781 templates_array ,
8882 template_sparse_mask ,
8983 recording ,
90- features_dict_or_folder ,
91-
84+ features_dict_or_folder ,
9285 # sparse_wfs,
9386 # sparse_mask,
9487 radius_um = 70 ,
@@ -120,7 +113,6 @@ def find_merge_pairs_from_features(
120113
121114 # compute template (no shift at this step)
122115
123-
124116 # templates = compute_template_from_sparse(
125117 # peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels, peak_shifts=None
126118 # )
@@ -131,13 +123,11 @@ def find_merge_pairs_from_features(
131123 # ms_after = features['ms_after']
132124 # svd_model = features['svd_model']
133125
134-
135126 # templates, final_sparsity_mask = get_templates_from_peaks_and_svd(
136- # recording, peaks, peak_labels, ms_before, ms_after, svd_model, peaks_svd, sparse_mask, operator="average",
127+ # recording, peaks, peak_labels, ms_before, ms_after, svd_model, peaks_svd, sparse_mask, operator="average",
137128 # )
138129 # dense_templates_array = templates.templates_array
139130
140-
141131 labels_set = unit_ids .tolist ()
142132
143133 max_chans = np .argmax (np .max (np .abs (templates_array ), axis = 1 ), axis = 1 )
@@ -319,8 +309,6 @@ def merge(
319309 final_shift = 0
320310 return is_merge , label0 , label1 , final_shift , merge_value
321311
322-
323-
324312 inds = np .concatenate ([inds0 , inds1 ])
325313 labels = np .zeros (inds .size , dtype = "int" )
326314 labels [inds0 .size :] = 1
@@ -332,7 +320,6 @@ def merge(
332320 wfs0 = wfs [:cut , :, :]
333321 wfs1 = wfs [cut :, :, :]
334322
335-
336323 # num_samples = template0.shape[0]
337324
338325 # template0 = template0_[num_shift : num_samples - num_shift, :]
@@ -368,7 +355,6 @@ def merge(
368355 # wfs1 = wfs1_[:, best_shift : best_shift + template0.shape[0], :]
369356 # template1 = template1_[best_shift : best_shift + template0.shape[0], :]
370357
371-
372358 feat0 = wfs0 .reshape (wfs0 .shape [0 ], - 1 )
373359 feat1 = wfs1 .reshape (wfs1 .shape [0 ], - 1 )
374360 feat = np .concatenate ([feat0 , feat1 ], axis = 0 )
@@ -377,11 +363,11 @@ def merge(
377363
378364 if use_svd :
379365 from sklearn .decomposition import TruncatedSVD
366+
380367 n_pca_features = 3
381368 tsvd = TruncatedSVD (n_pca_features , random_state = seed )
382369 feat = tsvd .fit_transform (feat )
383370
384-
385371 if isinstance (n_pca_features , float ):
386372 assert 0 < n_pca_features < 1 , "n_components should be in ]0, 1["
387373 nb_dimensions = min (feat .shape [0 ], feat .shape [1 ])
@@ -416,7 +402,6 @@ def merge(
416402 # else:
417403 # feat = feat
418404
419-
420405 feat0 = feat [:cut ]
421406 feat1 = feat [cut :]
422407
@@ -447,7 +432,6 @@ def merge(
447432 feat0 = feat [:cut ]
448433 feat1 = feat [cut :]
449434
450-
451435 if criteria == "isocut" :
452436 dipscore , cutpoint = isocut (feat )
453437 is_merge = dipscore < isocut_threshold
@@ -484,7 +468,7 @@ def merge(
484468 final_shift = 0
485469
486470 if DEBUG :
487- # if dipscore < 4:
471+ # if dipscore < 4:
488472 import matplotlib .pyplot as plt
489473
490474 flatten_wfs0 = wfs0 .swapaxes (1 , 2 ).reshape (wfs0 .shape [0 ], - 1 )
@@ -570,33 +554,25 @@ def merge_peak_labels_from_templates(
570554 )
571555 pair_mask = similarity > similarity_thresh
572556
573-
574- clean_labels , merge_template_array , merge_sparsity_mask , new_unit_ids = \
557+ clean_labels , merge_template_array , merge_sparsity_mask , new_unit_ids = (
575558 _apply_pair_mask_on_labels_and_recompute_templates (
576- pair_mask ,
577- peak_labels ,
578- unit_ids ,
579- templates_array ,
580- template_sparse_mask
559+ pair_mask , peak_labels , unit_ids , templates_array , template_sparse_mask
581560 )
561+ )
582562
583563 return clean_labels , merge_template_array , merge_sparsity_mask , new_unit_ids
584564
565+
585566def _apply_pair_mask_on_labels_and_recompute_templates (
586- pair_mask ,
587- peak_labels ,
588- unit_ids ,
589- templates_array ,
590- template_sparse_mask
567+ pair_mask , peak_labels , unit_ids , templates_array , template_sparse_mask
591568):
592569 """
593570 Resolve pairs graph.
594571 Apply to new labels.
595572 Recompute templates.
596573 """
597-
598- from scipy .sparse .csgraph import connected_components
599574
575+ from scipy .sparse .csgraph import connected_components
600576
601577 keep_template = np .ones (templates_array .shape [0 ], dtype = "bool" )
602578 clean_labels = peak_labels .copy ()
@@ -638,4 +614,3 @@ def _apply_pair_mask_on_labels_and_recompute_templates(
638614 merge_sparsity_mask = merge_sparsity_mask [keep_template , :]
639615
640616 return clean_labels , merge_template_array , merge_sparsity_mask , new_unit_ids
641-
0 commit comments