Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions driftmapviewer_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,37 +115,46 @@ def _filter_large_amplitude_spikes(
spike_amplitudes: np.ndarray,
spike_depths: np.ndarray,
large_amplitude_only_segment_size,
return_mask: bool = False,
) -> tuple[np.ndarray, ...]:
"""
Return spike properties with only the largest-amplitude spikes included. The probe
is split into segments, and within each segment the mean and std computed.
Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded
Splitting the probe is only done for the exclusion step, the returned array are flat.
Return spike properties with only the largest-amplitude spikes included.

Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns
copies of these arrays containing only the large amplitude spikes.
If return_mask=True, also returns `spike_bool` (mask into the *input* arrays).
"""
spike_bool = np.zeros_like(spike_amplitudes, dtype=bool)

segment_size_um = large_amplitude_only_segment_size
probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um
probe_segments_left_edges = (
np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um
)

for segment_left_edge in probe_segments_left_edges:
segment_right_edge = segment_left_edge + segment_size_um

spikes_in_seg = np.where(
np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge)
)[0]
if spikes_in_seg.size == 0:
continue

spike_amps_in_seg = spike_amplitudes[spikes_in_seg]
is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1)

spike_bool[spikes_in_seg] = is_high_amplitude
# avoid ddof=1 on tiny segments
if spike_amps_in_seg.size < 3:
spike_bool[spikes_in_seg] = True
continue

spike_times = spike_times[spike_bool]
spike_amplitudes = spike_amplitudes[spike_bool]
spike_depths = spike_depths[spike_bool]
thr = np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1)
spike_bool[spikes_in_seg] = spike_amps_in_seg > thr

return spike_times, spike_amplitudes, spike_depths
out_times = spike_times[spike_bool]
out_amps = spike_amplitudes[spike_bool]
out_depths = spike_depths[spike_bool]

if return_mask:
return out_times, out_amps, out_depths, spike_bool
return out_times, out_amps, out_depths

def _plot_kilosort_drift_map_raster(
spike_times: np.ndarray,
Expand Down
22 changes: 22 additions & 0 deletions example_run_hover_viewer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Example: run the hover viewer.

Usage:
python example_run_hover_viewer.py /path/to/kilosort/output --ks_version kilosort4 --decimate 20
"""

from __future__ import annotations
import argparse
from hover_viewer_pyqtgraph import run


def main():
ap = argparse.ArgumentParser()
ap.add_argument("sorter_output", type=str)
ap.add_argument("--ks_version", type=str, default="kilosort4", help="kilosort4 or kilosort1_3")
args = ap.parse_args()

run(args.sorter_output, ks_version="kilosort4")


if __name__ == "__main__":
main()
24 changes: 12 additions & 12 deletions helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,29 @@ def load_cluster_groups(cluster_path: Path) -> tuple[np.ndarray, ...]:
return cluster_ids, cluster_groups

# This is such a jankily written function fix it
def exclude_noise(sorter_output, spike_times, spike_amplitudes, spike_depths):
def exclude_noise(sorter_output, spike_times, spike_amplitudes, spike_depths, return_mask: bool = False):
""""""
if (cluster_path := sorter_output / "spike_clusters.npy").is_file(): # TODO: this can be csv?!?!?
if (cluster_path := sorter_output / "spike_clusters.npy").is_file():
spike_clusters = np.load(cluster_path)
else:
# this is a pain to have here, I don't think this case is realistic.
raise NotImplementedError("spike clusters.csv does not exist. Under what circumstance is this? probably very old.")
# spike_clusters = spike_templates.copy()
raise NotImplementedError("spike clusters.csv does not exist.")

if ( # short circuit ensures cluster_path is assigned appropriately
if (
(cluster_path := sorter_output / "cluster_groups.csv").is_file()
or (cluster_path := sorter_output / "cluster_group.tsv").is_file()
):
cluster_ids, cluster_groups = load_cluster_groups(cluster_path)

noise_cluster_ids = cluster_ids[cluster_groups == 0]
not_noise_clusters_by_spike = ~np.isin(spike_clusters.ravel(),
noise_cluster_ids)
spike_times = spike_times[not_noise_clusters_by_spike]
spike_amplitudes = spike_amplitudes[not_noise_clusters_by_spike]
spike_depths = spike_depths[not_noise_clusters_by_spike]
keep = ~np.isin(spike_clusters.ravel(), noise_cluster_ids)

return spike_times, spike_amplitudes, spike_depths
out_times = spike_times[keep]
out_amps = spike_amplitudes[keep]
out_depths = spike_depths[keep]

if return_mask:
return out_times, out_amps, out_depths, keep
return out_times, out_amps, out_depths

raise ValueError(
f"`exclude_noise` is `True` but there is no `cluster_groups.csv` or `.tsv` "
Expand Down
175 changes: 175 additions & 0 deletions hover_viewer_pyqtgraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Interactive drift map with hover-to-template waveform preview (PyQtGraph).

Run this as a script (see example_run_hover_viewer.py).
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
import numpy as np

import pyqtgraph as pg
from pyqtgraph.Qt import QtWidgets, QtCore

import ks_template_loader


@dataclass
class ViewerConfig:
ks_version: str = "kilosort4"
decimate: int | None = 20 # keep every Nth spike for speed
max_points: int | None = None # optional hard cap after decimate
point_size: float = 4.0
hover_radius_px: float = 12.0 # hover "snap" radius


class DriftMapHoverViewer(QtWidgets.QMainWindow):
def __init__(self, sorter_output: str | Path, cfg: ViewerConfig = ViewerConfig()):
super().__init__()
self.setWindowTitle("Drift map hover viewer")

sorter_output = Path(sorter_output)

# Load spikes + templates
spike_times, spike_amps, spike_depths = ks_template_loader.load_spikes(sorter_output, cfg.ks_version)
spike_templates, template_waveforms = ks_template_loader.load_templates(sorter_output, cfg.ks_version)

# Basic validation
n = min(len(spike_times), len(spike_templates), len(spike_depths), len(spike_amps))
spike_times = spike_times[:n]
spike_amps = spike_amps[:n]
spike_depths = spike_depths[:n]
spike_templates = spike_templates[:n]

# Decimate for responsiveness
if cfg.decimate and cfg.decimate > 1:
idx = np.arange(0, n, cfg.decimate, dtype=np.int64)
spike_times = spike_times[idx]
spike_amps = spike_amps[idx]
spike_depths = spike_depths[idx]
spike_templates = spike_templates[idx]

if cfg.max_points is not None and len(spike_times) > cfg.max_points:
spike_times = spike_times[:cfg.max_points]
spike_amps = spike_amps[:cfg.max_points]
spike_depths = spike_depths[:cfg.max_points]
spike_templates = spike_templates[:cfg.max_points]

self._spike_times = spike_times
self._spike_depths = spike_depths
self._spike_templates = spike_templates
self._template_waveforms = template_waveforms

# Central widget with two plots
cw = QtWidgets.QWidget()
self.setCentralWidget(cw)
layout = QtWidgets.QHBoxLayout(cw)

# Left: scatter
self.scatter_plot = pg.PlotWidget(title="Drift map (hover a point)")
self.scatter_plot.setLabel("bottom", "time (s)")
self.scatter_plot.setLabel("left", "depth (um)")
self.scatter_plot.showGrid(x=True, y=True, alpha=0.2)
layout.addWidget(self.scatter_plot, stretch=3)

# Right: waveform
self.wave_plot = pg.PlotWidget(title="Template waveform")
self.wave_plot.setLabel("bottom", "sample")
self.wave_plot.setLabel("left", "a.u.")
self.wave_plot.showGrid(x=True, y=True, alpha=0.2)
layout.addWidget(self.wave_plot, stretch=2)

self.wave_curve = self.wave_plot.plot([])

# Scatter item
self.scatter_item = pg.ScatterPlotItem(pxMode=True, size=cfg.point_size, hoverable=True)
self.scatter_plot.addItem(self.scatter_item)

# Attach per-point metadata via "data"
spots = [
{"pos": (float(t), float(d)), "data": int(tid)}
for t, d, tid in zip(self._spike_times, self._spike_depths, self._spike_templates)
]
self.scatter_item.addPoints(spots)

# Hover handling: use a signal proxy on scene mouse move, then find nearest point
self._cfg = cfg
self._last_tid = None

self._proxy = pg.SignalProxy(
self.scatter_plot.scene().sigMouseMoved,
rateLimit=60,
slot=self._on_mouse_moved,
)


def _on_mouse_moved(self, evt):
pos = evt[0] # QPointF in scene coords
vb = self.scatter_plot.getViewBox()
if not vb.sceneBoundingRect().contains(pos):
return

mouse_point = vb.mapSceneToView(pos)
mx, my = float(mouse_point.x()), float(mouse_point.y())

# Find candidate spikes in a time/depth window (fast coarse filter)
# Window size is based on the current view width/height.
xr = vb.viewRange()[0]
yr = vb.viewRange()[1]
xw = (xr[1] - xr[0]) * 0.02
yw = (yr[1] - yr[0]) * 0.02

x0, x1 = mx - xw, mx + xw
y0, y1 = my - yw, my + yw

mask = (self._spike_times >= x0) & (self._spike_times <= x1) & (self._spike_depths >= y0) & (self._spike_depths <= y1)
if not np.any(mask):
return

# Among candidates, pick the nearest in screen pixels
cand_idx = np.where(mask)[0]
if cand_idx.size == 0:
return

# Map candidates to scene pixels to compute distance in px
pts = np.vstack([self._spike_times[cand_idx], self._spike_depths[cand_idx]]).T
scene_pts = np.array([vb.mapViewToScene(pg.Point(p[0], p[1])) for p in pts], dtype=object)

dx = np.array([float(p.x()) - float(pos.x()) for p in scene_pts])
dy = np.array([float(p.y()) - float(pos.y()) for p in scene_pts])
dist2 = dx * dx + dy * dy
k = int(np.argmin(dist2))

if dist2[k] > (self._cfg.hover_radius_px ** 2):
return

tid = int(self._spike_templates[cand_idx[k]])
if tid == self._last_tid:
return

self._last_tid = tid
wf = self._template_waveforms[tid]
self.wave_curve.setData(np.arange(wf.size), wf)
self.wave_plot.setTitle(f"Template waveform (id={tid})")


def run(sorter_output: str | Path, ks_version: str = "kilosort4", decimate: int | None = 20):
app = QtWidgets.QApplication.instance() or QtWidgets.QApplication([])
cfg = ViewerConfig(ks_version=ks_version, decimate=decimate)
w = DriftMapHoverViewer(sorter_output, cfg)
w.resize(1300, 750)
w.show()
app.exec()


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="Drift map hover-to-template viewer (PyQtGraph)")
parser.add_argument("sorter_output", type=str, help="Path to Kilosort output folder")
parser.add_argument("--ks_version", type=str, default="kilosort4", help="kilosort4 or kilosort1_3")
parser.add_argument("--decimate", type=int, default=20, help="Keep every Nth spike")
args = parser.parse_args()

run(args.sorter_output, ks_version=args.ks_version, decimate=args.decimate)
5 changes: 4 additions & 1 deletion kilosort_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,7 @@ def get_spikes_info_ks4(
spike_amplitudes = np.load(sorter_output / "amplitudes.npy")
spike_depths = np.load(sorter_output / "spike_positions.npy")[:, 1]

return spike_times, spike_amplitudes, spike_depths
spike_templates = np.load(sorter_output / "spike_templates.npy") # rename spike_tempaltes_idx?
templates = np.load(sorter_output / "templates.npy") # rename unwihten?

return spike_times, spike_amplitudes, spike_depths, spike_templates, templates
Loading