Merged
Changes from 3 commits
55 changes: 38 additions & 17 deletions src/spikeinterface/postprocessing/template_similarity.py
@@ -71,7 +71,8 @@ def _select_extension_data(self, unit_ids):
# filter metrics dataframe
unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
new_similarity = self.data["similarity"][unit_indices][:, unit_indices]
return dict(similarity=new_similarity)
new_lags = self.data["lags"][unit_indices][:, unit_indices]
return dict(similarity=new_similarity, lags=new_lags)

def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
@@ -90,7 +91,7 @@ def _merge_extension_data(
new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids, new_sorting_analyzer.channel_ids
)

new_similarity, _ = compute_similarity_with_templates_array(
new_similarity, new_lags = compute_similarity_with_templates_array(
new_templates_array,
all_templates_array,
method=self.params["method"],
@@ -101,10 +102,12 @@
)

old_similarity = self.data["similarity"]
old_lags = self.data["lags"]

all_new_unit_ids = new_sorting_analyzer.unit_ids
n = all_new_unit_ids.size
similarity = np.zeros((n, n), dtype=old_similarity.dtype)
lags = np.zeros((n, n), dtype=old_lags.dtype)

local_mask = ~np.isin(all_new_unit_ids, new_unit_ids)
sub_units_ids = all_new_unit_ids[local_mask]
@@ -117,13 +120,20 @@
similarity[unit_ind1, sub_units_inds] = s
similarity[sub_units_inds, unit_ind1] = s

l = self.data["lags"][old_ind1, old_units_inds]
lags[unit_ind1, sub_units_inds] = l
lags[sub_units_inds, unit_ind1] = l

# insert new similarity and lags both ways
for unit_ind, unit_id in enumerate(all_new_unit_ids):
if unit_id in new_unit_ids:
new_index = list(new_unit_ids).index(unit_id)
similarity[unit_ind, :] = new_similarity[new_index, :]
similarity[:, unit_ind] = new_similarity[new_index, :]

lags[unit_ind, :] = new_lags[new_index, :]
lags[:, unit_ind] = new_lags[new_index, :]

return dict(similarity=similarity, lags=lags)

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
@@ -142,7 +152,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids_f, new_sorting_analyzer.channel_ids
)

new_similarity, _ = compute_similarity_with_templates_array(
new_similarity, new_lags = compute_similarity_with_templates_array(
new_templates_array,
all_templates_array,
method=self.params["method"],
@@ -153,10 +163,12 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
)

old_similarity = self.data["similarity"]
old_lags = self.data["lags"]

all_new_unit_ids = new_sorting_analyzer.unit_ids
n = all_new_unit_ids.size
similarity = np.zeros((n, n), dtype=old_similarity.dtype)
lags = np.zeros((n, n), dtype=old_lags.dtype)

local_mask = ~np.isin(all_new_unit_ids, new_unit_ids_f)
sub_units_ids = all_new_unit_ids[local_mask]
@@ -169,13 +181,20 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
similarity[unit_ind1, sub_units_inds] = s
similarity[sub_units_inds, unit_ind1] = s

l = self.data["lags"][old_ind1, old_units_inds]
lags[unit_ind1, sub_units_inds] = l
lags[sub_units_inds, unit_ind1] = l

# insert new similarity and lags both ways
for unit_ind, unit_id in enumerate(all_new_unit_ids):
if unit_id in new_unit_ids_f:
new_index = list(new_unit_ids_f).index(unit_id)
similarity[unit_ind, :] = new_similarity[new_index, :]
similarity[:, unit_ind] = new_similarity[new_index, :]

lags[unit_ind, :] = new_lags[new_index, :]
lags[:, unit_ind] = new_lags[new_index, :]

return dict(similarity=similarity, lags=lags)

def _run(self, verbose=False):
@@ -184,7 +203,7 @@ def _run(self, verbose=False):
self.sorting_analyzer, return_in_uV=self.sorting_analyzer.return_in_uV
)
sparsity = self.sorting_analyzer.sparsity
similarity, _ = compute_similarity_with_templates_array(
similarity, lags = compute_similarity_with_templates_array(
templates_array,
templates_array,
method=self.params["method"],
@@ -194,10 +213,14 @@
other_sparsity=sparsity,
)
self.data["similarity"] = similarity
self.data["lags"] = lags

def _get_data(self):
return self.data["similarity"]

def get_lags(self):
return self.data["lags"]


# @alessio: compute_template_similarity() is now one inner SortingAnalyzer only
register_result_extension(ComputeTemplateSimilarity)
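
For context, a minimal usage sketch of the new `lags` output through the extension API. This is not part of the diff: the extension name "template_similarity", the `method`/`max_lag_ms` keyword names, and the synthetic-data helpers are assumptions, so treat it as illustrative only.

```python
# Hedged sketch (assumptions noted above), not code from this PR.
from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording

# Toy data; any recording/sorting pair would do.
recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)

analyzer = create_sorting_analyzer(sorting, recording, format="memory")
analyzer.compute(["random_spikes", "templates"])

# Assumed keyword names; adjust to the actual extension params.
analyzer.compute("template_similarity", method="l2", max_lag_ms=0.5)

ext = analyzer.get_extension("template_similarity")
similarity = ext.get_data()   # (n_units, n_units) similarity matrix
lags = ext.get_lags()         # (n_units, n_units) best-matching shifts, in samples

# lags[i, j] is the shift at which unit j's template best matches unit i's
# template for the chosen method.
```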
@@ -235,9 +258,9 @@ def _compute_similarity_matrix_numpy(
tgt_templates = tgt_sliced_templates[overlapping_templates]
for gcount, j in enumerate(overlapping_templates):
# symmetric values are handled later
if same_array and j < i:
# no need exhaustive looping when same template
continue
# if same_array and j < i:
# no need exhaustive looping when same template
# continue
src = src_template[:, local_mask[j]].reshape(1, -1)
tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1)

@@ -259,10 +282,8 @@
distances[count, i, j] = 1 - distances[count, i, j]

if same_array:
distances[count, j, i] = distances[count, i, j]
distances[num_shifts_both_sides - count - 1, j, i] = distances[count, i, j]

if same_array and num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T
return distances


@@ -332,9 +353,9 @@ def _compute_similarity_matrix_numba(

j = overlapping_templates[gcount]
# symmetric values are handled later
if same_array and j < i:
# no need exhaustive looping when same template
continue
# if same_array and j < i:
# no need exhaustive looping when same template
# continue
src = src_template[:, local_mask[j]].flatten()
tgt = (tgt_templates[gcount][:, local_mask[j]]).flatten()

@@ -370,10 +391,10 @@ def _compute_similarity_matrix_numba(
distances[count, i, j] = 1 - distances[count, i, j]

if same_array:
distances[count, j, i] = distances[count, i, j]
distances[num_shifts_both_sides - count - 1, j, i] = distances[count, i, j]

if same_array and num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T
# if same_array and num_shifts != 0:
# distances[num_shifts_both_sides - count - 1] = distances[count].T

return distances

@@ -447,7 +468,7 @@ def compute_similarity_with_templates_array(
distances = np.min(distances, axis=0)
similarity = 1 - distances

return similarity, lags
return similarity, lags.astype(np.int32)


def compute_template_similarity_by_pair(
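
For reference, a small self-contained sketch of calling compute_similarity_with_templates_array directly and reading the two return values. Dense toy templates are used, and the keyword names beyond `method` and `other_sparsity` (which appear in the hunks above) are assumptions, so treat the exact signature as illustrative rather than authoritative.

```python
import numpy as np
from spikeinterface.postprocessing.template_similarity import (
    compute_similarity_with_templates_array,
)

# Toy dense templates: (num_units, num_samples, num_channels).
rng = np.random.default_rng(0)
templates_array = rng.standard_normal((4, 90, 8)).astype("float32")

# `num_shifts`, `support`, and passing None for both sparsities are assumptions.
similarity, lags = compute_similarity_with_templates_array(
    templates_array,
    templates_array,
    method="cosine",
    num_shifts=5,
    support="union",
    sparsity=None,
    other_sparsity=None,
)

# similarity[i, j]: best similarity over all tested shifts.
# lags[i, j]: the shift (in samples, int32 after this change) that achieved it.
i, j = 0, 1
print(similarity[i, j], lags[i, j])
```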
1 change: 0 additions & 1 deletion src/spikeinterface/sorters/internal/lupin.py
@@ -109,7 +109,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.clustering.main import find_clusters_from_peaks, clustering_methods
from spikeinterface.sortingcomponents.tools import remove_empty_templates
from spikeinterface.preprocessing import correct_motion
from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording
from spikeinterface.sortingcomponents.tools import clean_templates, compute_sparsity_from_peaks_and_label