189 changes: 188 additions & 1 deletion src/spikeinterface/extractors/phykilosortextractors.py
@@ -2,12 +2,25 @@

from typing import Optional
from pathlib import Path
import json

import numpy as np

from spikeinterface.core import BaseSorting, BaseSortingSegment, read_python
from spikeinterface.core import (
    BaseSorting,
    BaseSortingSegment,
    read_python,
    generate_ground_truth_recording,
    ChannelSparsity,
    ComputeTemplates,
    create_sorting_analyzer,
    SortingAnalyzer,
)
from spikeinterface.core.core_tools import define_function_from_class

from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations
from probeinterface import read_prb, Probe


class BasePhyKilosortSortingExtractor(BaseSorting):
"""Base SortingExtractor for Phy and Kilosort output folder.
@@ -302,3 +315,177 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove

read_phy = define_function_from_class(source_class=PhySortingExtractor, name="read_phy")
read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort")


def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True) -> SortingAnalyzer:
    """
    Load kilosort output into a SortingAnalyzer.
    Output from kilosort version 4.1 and above is supported.
Collaborator:

Is there any way to check the version from the output? If not, we should ask them to add it on the KS repo, as this would be a useful general addition. But maybe we could check that the kilosortX.log is not < 4? (IIRC the logs are formatted in this way.)

Member Author:

Have asked on KiloSort

Member Author:

If you've run directly from kilosort, you'll have the version in kilosort4.log. So we could check there, and if it's not there, we have to guess...

Collaborator:

Great, that sounds good. If this tool is only for kilosort4 then just checking for the existence of that log file should do (unless it's extended to other versions).

Member Author:

Hello, the log files only appeared at v4.0.33. Thinking of other ways to check...

Member Author:

Do you know of any files which KS2/2.5 definitely don't have in their output, that KS4 does?

Member Author:

Hello, I've added a _guess_kilosort_version function to isolate this logic. Let's compare some outputs and see if we can make it reasonable.
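
A minimal sketch of the version check discussed above (not the PR's actual _guess_kilosort_version; the kilosort4.log contents and the regex are assumptions):

import re
from pathlib import Path


def _guess_kilosort_version_sketch(folder_path):
    """Best-effort Kilosort version guess from an output folder (log format assumed)."""
    log_path = Path(folder_path) / "kilosort4.log"
    if not log_path.is_file():
        # pre-4.0.33 output, or a pipeline that does not write the log
        return None
    # assume the log contains a semver-like version string somewhere
    match = re.search(r"(\d+)\.(\d+)\.(\d+)", log_path.read_text(errors="ignore"))
    if match is None:
        return None
    return tuple(int(part) for part in match.groups())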


    Parameters
    ----------
    folder_path : str or Path
        Path to the output Phy folder (containing the params.py).
    compute_extras : bool, default: False
        Compute the extra extensions: unit_locations, correlograms, template_similarity, isi_histograms, template_metrics, quality_metrics.
Collaborator:

I guess the only thing missing is metrics computed from waveforms? Following use of this function, the user can run "compute_waveforms" after attaching a recording object?

I always thought it would be cool to load the kilosort waveforms directly in SI, but this doesn't seem to be possible for older versions. For kilosort4 (outside the scope of this PR, just a discussion) do you think we could ask Jacob to write a public kilosort function that returns the waveforms in a 3D array? I guess on the kilosort4 side they are still stored as the first few components in PCA space (?). It would be cool to have an option to load the original KS4 waveforms too in this function. I always wondered exactly how they compare to those generated in spikeinterface when drift correction is run on the kilosort side (and spikeinterface waveforms are cut out from the preprocessed recording pre-drift-correction).

Member Author:

Almost all the metrics depend on waveform features like amplitudes and locations, rather than the waveforms themselves. I don't think we miss out on anything (except the PCA metrics, because the PCs are impossible to load, I think...)

Does kilosort save waveforms? My output doesn't seem to have any - just their features such as locations, amps and PCs.

Collaborator (@JoeZiminski, Nov 19, 2025):

That's cool, okay great and makes sense. Would it be possible for the user to attach a recording and compute_waveforms afterwards if they wanted PCs?

The old versions did use to store them under pc_features.npy (see here); it needs to be indexed with pc_feature_ind.npy. I forget the exact process but it is possible to reconstruct them from the compressed spikes (stored as scores on the first few PCs). However, this was the issue with reconstructing across versions, as the procedure differs across KS versions. I'm not 100% sure about KS4, I'm actually running it on some data tomorrow so will get back!
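
A sketch of that workflow (the recording path and parameters are placeholders; it assumes set_temporary_recording and the "waveforms"/"principal_components" extensions behave as in current spikeinterface):

import spikeinterface.full as si

# load the analyzer from the kilosort output (function added in this PR)
analyzer = kilosort_output_to_analyzer("path/to/kilosort_output")

# hypothetical raw recording; substitute your own reader and parameters
recording = si.read_binary(
    "path/to/recording.dat", sampling_frequency=30_000.0, dtype="int16", num_channels=384
)

# attach the recording so waveform-based extensions can be computed
analyzer.set_temporary_recording(recording)
analyzer.compute(["waveforms", "principal_components"])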

    unwhiten : bool, default: True
        Unwhiten the templates computed by kilosort.

    Returns
    -------
    sorting_analyzer : SortingAnalyzer
        A SortingAnalyzer object.
    """

    phy_path = Path(folder_path)

    sorting = read_phy(phy_path)
    sampling_frequency = sorting.sampling_frequency
    # estimate the duration from the last spike time; the fake recording below
    # only needs to span all spikes
    duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1

    if (phy_path / "probe.prb").is_file():
        probegroup = read_prb(phy_path / "probe.prb")
        probe = probegroup.probes[0]
    elif (phy_path / "channel_positions.npy").is_file():
        probe = Probe(si_units="um")
        channel_positions = np.load(phy_path / "channel_positions.npy")
        probe.set_contacts(channel_positions)
        probe.set_device_channel_indices(range(probe.get_contact_count()))
    else:
        raise AssertionError(f"Cannot read probe layout from folder {phy_path}.")

    # to make the initial analyzer, we'll use a fake recording and set it to None later
    recording, _ = generate_ground_truth_recording(
        probe=probe, sampling_frequency=sampling_frequency, durations=[duration]
    )

    sparsity = _make_sparsity_from_templates(sorting, recording, phy_path)

    sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity)

    # first compute random spikes. These do nothing, but are needed for si-gui to run
    sorting_analyzer.compute("random_spikes")

    _make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten)
    _make_locations(sorting_analyzer, phy_path)
    _make_amplitudes(sorting_analyzer, phy_path)

    if compute_extras:
        sorting_analyzer.compute(
            {
                "unit_locations": {},
                "correlograms": {},
                "template_similarity": {},
                "isi_histograms": {},
                "template_metrics": {"include_multi_channel_metrics": True},
                "quality_metrics": {},
            }
        )

    sorting_analyzer._recording = None
    return sorting_analyzer
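
A minimal usage sketch (assuming the function is exposed from spikeinterface.extractors as in this diff; the folder path is hypothetical):

from spikeinterface.extractors import kilosort_output_to_analyzer

# hypothetical Kilosort 4 / Phy output folder containing params.py
analyzer = kilosort_output_to_analyzer("path/to/kilosort4_output", compute_extras=True)
print(analyzer.get_loaded_extension_names())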


def _make_amplitudes(sa, phy_path: Path):
    # inject kilosort's precomputed amplitudes as a `spike_amplitudes` extension
    amplitudes_extension = ComputeSpikeAmplitudes(sa)

    amps_np = np.load(phy_path / "amplitudes.npy")

    amplitudes_extension.data = {"amplitudes": amps_np}
    amplitudes_extension.params = {"peak_sign": "neg"}
    amplitudes_extension.run_info = {"run_completed": True}

    sa.extensions["spike_amplitudes"] = amplitudes_extension


def _make_locations(sa, phy_path):
    # inject kilosort's spike positions as a `spike_locations` extension,
    # converting the plain (num_spikes, num_dims) array to a structured array
    locations_extension = ComputeSpikeLocations(sa)

    locs_np = np.load(phy_path / "spike_positions.npy")

    num_dims = locs_np.shape[1]
    column_names = ["x", "y", "z"][:num_dims]
    dtype = [(name, locs_np.dtype) for name in column_names]

    structured_array = np.zeros(len(locs_np), dtype=dtype)
    for dim, column_name in enumerate(column_names):
        structured_array[column_name] = locs_np[:, dim]

    locations_extension.data = {"spike_locations": structured_array}
    locations_extension.params = {}
    locations_extension.run_info = {"run_completed": True}

    sa.extensions["spike_locations"] = locations_extension


def _make_sparsity_from_templates(sort, rec, phy_path):

    templates = np.load(phy_path / "templates.npy")

    unit_ids = sort.unit_ids
    channel_ids = rec.channel_ids

    # The raw templates are dense, with shape (num_units, num_samples, num_channels),
    # but are zero on many channels, which implicitly defines the sparsity
    mask = np.sum(np.abs(templates), axis=1) != 0
    return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids)
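
A toy illustration of the mask derivation (values made up for the example):

import numpy as np

# two units, three samples, four channels; unit 0 lives on channels 0-1,
# unit 1 on channels 1-3
templates = np.zeros((2, 3, 4))
templates[0, :, :2] = 1.0
templates[1, :, 1:] = -0.5

mask = np.sum(np.abs(templates), axis=1) != 0
# mask -> [[ True,  True, False, False],
#          [False,  True,  True,  True]]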


def _make_templates(sa, phy_path, mask, sampling_frequency, unwhiten=True):

    template_extension = ComputeTemplates(sa)

    whitened_templates = np.load(phy_path / "templates.npy")
    wh_inv = np.load(phy_path / "whitening_mat_inv.npy")
Collaborator:

I think this is only reliable on KS4 (in KS2.5-3 the inverse whitening matrix is a scaled identity). It is clear from the docstring that this is only for KS4, but I'm sure people will try it on other versions. Checking the log file is probably the easiest way to catch this, but another possible safety net could be to check that wh_inv is not all zeros off the diagonal.

Member Author:

Ok, so we could try to check the version number, then if it's less than 4, we only allow whitened templates?

Collaborator:

I think this would work. I am not 100% confident without checking across versions; I will try to sort out and send some older datasets tomorrow! I think indeed unwhitening would not be possible for KS2.5 and KS3.
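
A possible safety net along the reviewer's suggestion (a sketch, not in the PR; the check and warning text are assumptions):

import warnings

import numpy as np


def _check_whitening_inverse(wh_inv):
    """Warn if wh_inv looks like a scaled identity, as in KS2.5/KS3 output."""
    off_diagonal = wh_inv - np.diag(np.diag(wh_inv))
    if np.allclose(off_diagonal, 0.0):
        warnings.warn(
            "whitening_mat_inv has no off-diagonal structure; this output was "
            "probably not produced by Kilosort 4, so unwhitening may be unreliable."
        )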

    if unwhiten:
        new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv, mask)
    else:
        new_templates = whitened_templates

    template_extension.data = {"average": new_templates}

    ops_path = phy_path / "ops.npy"
    if ops_path.is_file():
        ops = np.load(ops_path, allow_pickle=True)

        # nt0min is the number of samples before the peak; nt is the total template width
        samples_before = ops.item(0).get("nt0min")
        nt = ops.item(0).get("nt")
        samples_after = nt - samples_before

        ms_before = samples_before / (sampling_frequency / 1000)
        ms_after = samples_after / (sampling_frequency / 1000)

        params = {
            "operators": ["average"],
            "ms_before": ms_before,
            "ms_after": ms_after,
            "peak_sign": "neg",
        }
        template_extension.params = params

    template_extension.run_info = {"run_completed": True}

    sa.extensions["templates"] = template_extension


def _compute_unwhitened_templates(whitened_templates, wh_inv, mask):
    # apply the inverse whitening matrix to each template, restricted to each
    # unit's sparse channels so that non-sparse channels stay exactly zero
    template_shape = np.shape(whitened_templates)
    new_templates = np.zeros(template_shape)

    sparsity_channel_ids = [np.arange(template_shape[-1])[unit_sparsity] for unit_sparsity in mask]

    for unit_index, unit_channels in enumerate(sparsity_channel_ids):
        for b in unit_channels:
            for c in unit_channels:
                new_templates[unit_index, :, b] += wh_inv[b, c] * whitened_templates[unit_index, :, c]

    return new_templates
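
For reference, a vectorized sketch that should match the loop above (it relies on the whitened templates already being zero off each unit's sparse channels, which is how the sparsity was derived):

import numpy as np


def _compute_unwhitened_templates_vectorized(whitened_templates, wh_inv, mask):
    # dense[u, t, b] = sum_c wh_inv[b, c] * whitened_templates[u, t, c]
    dense = whitened_templates @ wh_inv.T
    # zero out channels outside each unit's sparsity, matching the looped version
    return dense * mask[:, np.newaxis, :]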