Skip to content

Commit 32578f6

Browse files
ttngu207claude
andcommitted
Fix sparse template handling in kilosort reader for SI-exported Phy data
SpikeInterface's export_to_phy exports sparse templates by default (since v0.101.0), with templates.npy shape (n_templates, n_samples, max_sparse_channels) instead of the full (n_templates, n_samples, n_channels) format from native Kilosort. The companion file templates_ind.npy maps (template_idx, sparse_channel_idx) to actual channel indices, with -1 indicating padding. This fix updates get_best_channel() and extract_spike_depths() to: - Check if templates_ind exists (indicates SI-exported sparse format) - Use templates_ind to map sparse indices to actual channel indices - Fall back to original behavior for native Kilosort (dense) format Without this fix, spike_sites and best_channel values are incorrect when reading SI-exported Phy curations, as argmax returns indices into the sparse representation rather than actual channel indices. Related: dj-sciops/nei_nienborg#111 (Issue #2) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c4fafcd commit 32578f6

File tree

1 file changed

+43
-3
lines changed

1 file changed

+43
-3
lines changed

element_array_ephys/readers/kilosort.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,28 @@ def get_best_channel(self, unit):
151151
]
152152
channel_templates = self.data["templates"][template_idx, :, :]
153153
max_channel_idx = np.abs(channel_templates).max(axis=0).argmax()
154-
max_channel = self.data["channel_map"][max_channel_idx]
154+
155+
# Handle sparse templates (SpikeInterface export_to_phy format)
156+
# templates_ind maps (template_idx, sparse_channel_idx) -> actual channel_idx
157+
# Value of -1 indicates padding (no channel at that position)
158+
if "templates_ind" in self.data:
159+
# Use templates_ind to get actual channel index for this template
160+
actual_channel_idx = self.data["templates_ind"][template_idx, max_channel_idx]
161+
if actual_channel_idx >= 0:
162+
max_channel = self.data["channel_map"][actual_channel_idx]
163+
else:
164+
# Fallback if the max is in a padded position (shouldn't happen normally)
165+
log.warning(
166+
f"Unit {unit}: max amplitude in padded channel position, "
167+
"falling back to first valid channel"
168+
)
169+
valid_channels = self.data["templates_ind"][template_idx]
170+
valid_idx = valid_channels[valid_channels >= 0][0]
171+
max_channel = self.data["channel_map"][valid_idx]
172+
max_channel_idx = valid_idx
173+
else:
174+
# Dense templates (native Kilosort format) - use channel_map directly
175+
max_channel = self.data["channel_map"][max_channel_idx]
155176

156177
return max_channel, max_channel_idx
157178

@@ -179,9 +200,28 @@ def extract_spike_depths(self):
179200
self._data["spike_depths"] = None
180201

181202
# ---- extract spike sites ----
203+
# For each template, find the channel with maximum amplitude
182204
max_site_ind = np.argmax(np.abs(self.data["templates"]).max(axis=1), axis=1)
183-
spike_site_ind = max_site_ind[self.data["spike_templates"]]
184-
self._data["spike_sites"] = self.data["channel_map"][spike_site_ind]
205+
206+
# Handle sparse templates (SpikeInterface export_to_phy format)
207+
# templates_ind maps (template_idx, sparse_channel_idx) -> actual channel_idx
208+
if "templates_ind" in self.data:
209+
# Map sparse indices to actual channel indices using templates_ind
210+
# templates_ind shape: (n_templates, max_sparse_channels)
211+
templates_ind = self.data["templates_ind"]
212+
# For each template, get the actual channel index at the max position
213+
actual_channel_indices = templates_ind[
214+
np.arange(len(max_site_ind)), max_site_ind
215+
]
216+
# Map spike templates to their actual channel indices
217+
spike_actual_channel_ind = actual_channel_indices[
218+
self.data["spike_templates"]
219+
]
220+
self._data["spike_sites"] = self.data["channel_map"][spike_actual_channel_ind]
221+
else:
222+
# Dense templates (native Kilosort format) - use channel_map directly
223+
spike_site_ind = max_site_ind[self.data["spike_templates"]]
224+
self._data["spike_sites"] = self.data["channel_map"][spike_site_ind]
185225

186226

187227
def extract_clustering_info(cluster_output_dir):

0 commit comments

Comments
 (0)