Add Kilosort output to SortingAnalyzer helper function #4202
base: main
Changes from 3 commits
@@ -2,12 +2,25 @@
 from typing import Optional
 from pathlib import Path
 import json

 import numpy as np

-from spikeinterface.core import BaseSorting, BaseSortingSegment, read_python
+from spikeinterface.core import (
+    BaseSorting,
+    BaseSortingSegment,
+    read_python,
+    generate_ground_truth_recording,
+    ChannelSparsity,
+    ComputeTemplates,
+    create_sorting_analyzer,
+    SortingAnalyzer,
+)
 from spikeinterface.core.core_tools import define_function_from_class

+from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations
+from probeinterface import read_prb, Probe


 class BasePhyKilosortSortingExtractor(BaseSorting):
     """Base SortingExtractor for Phy and Kilosort output folder.

@@ -302,3 +315,177 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove
 read_phy = define_function_from_class(source_class=PhySortingExtractor, name="read_phy")
 read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort")
+
+
+def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True) -> SortingAnalyzer:
+    """
+    Load kilosort output into a SortingAnalyzer.
+    Output from kilosort version 4.1 and above is supported.
+
+    Parameters
+    ----------
+    folder_path : str or Path
+        Path to the output Phy folder (containing the params.py).
+    compute_extras : bool, default: False
+        Compute the extra extensions: unit_locations, correlograms, template_similarity, isi_histograms, template_metrics, quality_metrics.
+    unwhiten : bool, default: True
+        Unwhiten the templates computed by kilosort.
+
+    Returns
+    -------
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
+    """
+
+    phy_path = Path(folder_path)
+
+    sorting = read_phy(phy_path)
+    sampling_frequency = sorting.sampling_frequency
+    duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1
+
+    if (phy_path / "probe.prb").is_file():
+        probegroup = read_prb(phy_path / "probe.prb")
+        probe = probegroup.probes[0]
+    elif (phy_path / "channel_positions.npy").is_file():
+        probe = Probe(si_units="um")
+        channel_positions = np.load(phy_path / "channel_positions.npy")
+        probe.set_contacts(channel_positions)
+        probe.set_device_channel_indices(range(probe.get_contact_count()))
+    else:
| AssertionError("Cannot read probe layout from folder {phy_path}.") | ||
chrishalcrow marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
+
+    # to make the initial analyzer, we'll use a fake recording and set it to None later
+    recording, _ = generate_ground_truth_recording(
+        probe=probe, sampling_frequency=sampling_frequency, durations=[duration]
+    )
+
+    sparsity = _make_sparsity_from_templates(sorting, recording, phy_path)
+
+    sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity)
+
+    # first compute random spikes. These do nothing, but are needed for si-gui to run
+    sorting_analyzer.compute("random_spikes")
+
+    _make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten)
+    _make_locations(sorting_analyzer, phy_path)
+    _make_amplitudes(sorting_analyzer, phy_path)
+
+    if compute_extras:
+        sorting_analyzer.compute(
+            {
+                "unit_locations": {},
+                "correlograms": {},
+                "template_similarity": {},
+                "isi_histograms": {},
+                "template_metrics": {"include_multi_channel_metrics": True},
+                "quality_metrics": {},
+            }
+        )
+
+    sorting_analyzer._recording = None
+    return sorting_analyzer
+
+
+def _make_amplitudes(sa, phy_path: Path):
+
+    amplitudes_extension = ComputeSpikeAmplitudes(sa)
+
+    amps_np = np.load(phy_path / "amplitudes.npy")
+
+    amplitudes_extension.data = {}
+    amplitudes_extension.data["amplitudes"] = amps_np
+
+    params = {"peak_sign": "neg"}
+
+    amplitudes_extension.params = params
+
+    amplitudes_extension.run_info = {"run_completed": True}
+
+    sa.extensions["spike_amplitudes"] = amplitudes_extension
+
+
+def _make_locations(sa, phy_path):
+
+    locations_extension = ComputeSpikeLocations(sa)
+
+    locs_np = np.load(phy_path / "spike_positions.npy")
+
+    num_dims = len(locs_np[0])
+    column_names = ["x", "y", "z"][:num_dims]
+    dtype = [(name, locs_np.dtype) for name in column_names]
+
+    structured_array = np.zeros(len(locs_np), dtype=dtype)
+
+    for a, column_name in enumerate(column_names):
+        structured_array[column_name] = locs_np[:, a]
+
+    locations_extension.data = {}
+    locations_extension.data["spike_locations"] = structured_array
+
+    params = {}
+    locations_extension.params = params
+
+    locations_extension.run_info = {"run_completed": True}
+
+    sa.extensions["spike_locations"] = locations_extension
+
+
+def _make_sparsity_from_templates(sort, rec, phy_path):
+
+    templates = np.load(phy_path / "templates.npy")
+
+    unit_ids = sort.unit_ids
+    channel_ids = rec.channel_ids
+
+    # The raw templates are dense, with shape (num units) x (num samples) x (num channels),
+    # but are zero on many channels, which implicitly defines the sparsity
+    mask = np.sum(np.abs(templates), axis=1) != 0
+    return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids)
+
+
+def _make_templates(sa, phy_path, mask, sampling_frequency, unwhiten=True):
+
+    template_extension = ComputeTemplates(sa)
+
+    whitened_templates = np.load(phy_path / "templates.npy")
+    wh_inv = np.load(phy_path / "whitening_mat_inv.npy")
+
+    new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv, mask) if unwhiten else whitened_templates
+
+    template_extension.data = {"average": new_templates}
+
+    ops_path = phy_path / "ops.npy"
+    if ops_path.is_file():
+        ops = np.load(ops_path, allow_pickle=True)
+
+        samples_before = ops.item(0).get("nt0min")
+        nt = ops.item(0).get("nt")
+
+        samples_after = nt - samples_before
+
+        ms_before = samples_before / (sampling_frequency // 1000)
+        ms_after = samples_after / (sampling_frequency // 1000)
+
+    params = {
+        "operators": ["average"],
+        "ms_before": ms_before,
+        "ms_after": ms_after,
+        "peak_sign": "neg",
+    }
+
+    template_extension.params = params
+    template_extension.run_info = {"run_completed": True}
+
+    sa.extensions["templates"] = template_extension
+
+
+def _compute_unwhitened_templates(whitened_templates, wh_inv, mask):
+
+    template_shape = np.shape(whitened_templates)
+    new_templates = np.zeros(template_shape)
+
+    sparsity_channel_ids = [np.arange(template_shape[-1])[unit_sparsity] for unit_sparsity in mask]
+
+    for a, unit_sparsity in enumerate(sparsity_channel_ids):
+        for b in unit_sparsity:
+            for c in unit_sparsity:
+                new_templates[a, :, b] += wh_inv[b, c] * whitened_templates[a, :, c]
+
+    return new_templates
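
For reference, a call to the new helper might look like the sketch below. The import location, the folder path, and whether the function is re-exported from `spikeinterface.extractors` alongside `read_phy`/`read_kilosort` are assumptions here, not part of the diff.

```python
# Hypothetical usage sketch -- the import path and folder below are placeholders.
from spikeinterface.extractors import kilosort_output_to_analyzer

analyzer = kilosort_output_to_analyzer("/path/to/kilosort4_output", compute_extras=True)

# The pre-filled extensions behave like any other SortingAnalyzer extension.
templates = analyzer.get_extension("templates").get_data()
amplitudes = analyzer.get_extension("spike_amplitudes").get_data()
```

Because the recording is set to None at the end, the returned analyzer holds only the sorting plus the templates, amplitudes and locations read back from the Kilosort folder.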
Is there any way to check the version from the output? If not, we should ask them to add one on the KS repo, as this would be a useful general addition. But maybe we could check that the `kilosortX.log` is not < 4? (IIRC the logs are formatted in this way.)

Have asked on KiloSort.

If you've run directly from kilosort, you'll have the version in `kilosort4.log`. So we could check there, and if it's not there, we have to guess...

Great, that sounds good. If this tool is only for `kilosort4`, then just checking for the existence of that log file should do (unless it's extended to other versions).

Hello, the log files only appeared at v4.0.33. Thinking of other ways to check...

Do you know of any files which KS2/2.5 definitely don't have in their output, that KS4 does?

Hello, I've added a `_guess_kilosort_version` function to isolate this logic. Let's compare some outputs and see if we can make it reasonable.
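
For the version check discussed above, a minimal sketch of the idea, assuming the version string can be found somewhere in `kilosort4.log` when that file exists; the log format, the regex, and the fallback behaviour are assumptions, not the `_guess_kilosort_version` implementation referenced in the last comment.

```python
import re
from pathlib import Path


def _guess_kilosort_version(folder_path):
    """Best-effort guess of the Kilosort version that produced a Phy output folder.

    Returns a version string like "4.0.33" if one can be found, otherwise None.
    """
    log_path = Path(folder_path) / "kilosort4.log"

    # Kilosort >= 4.0.33 writes a log file; try to pull a version number out of it.
    if log_path.is_file():
        text = log_path.read_text(errors="ignore")
        match = re.search(r"version\s*[:=]?\s*(\d+\.\d+(?:\.\d+)?)", text, re.IGNORECASE)
        if match:
            return match.group(1)
        # The log exists but no version string was found: still a KS4-style output.
        return "4"

    # No log file: either an older KS4 (< 4.0.33) or a different Kilosort version entirely.
    return None
```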