Skip to content

Commit b4f613c

Browse files
committed
iblapps: viewspikes initial commit
1 parent 8222e49 commit b4f613c

File tree

6 files changed

+470
-0
lines changed

6 files changed

+470
-0
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
easyqc
12
ibllib
23
pyqtgraph
34
simpleITK

viewspikes/data.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from pathlib import Path
2+
import shutil
3+
4+
from ibllib.io import spikeglx
5+
from oneibl.webclient import dataset_record_to_url
6+
7+
import numpy as np
8+
import scipy.signal
9+
10+
import alf.io
11+
12+
13+
CHUNK_DURATION_SECS = 1
14+
OUTPUT_TO_TEST = True
15+
16+
17+
def get_ks2_batch(ks2memmap, ibatch):
18+
BATCH_SIZE = 65600
19+
NTR = 384
20+
offset = BATCH_SIZE * NTR * ibatch
21+
from_to = np.array([0, BATCH_SIZE * NTR])
22+
slic = slice(from_to[0] + offset, from_to[1] + offset)
23+
24+
ks2 = np.reshape(ks2memmap[slice(from_to[0] + offset, from_to[1] + offset)], (NTR, BATCH_SIZE))
25+
return ks2
26+
27+
28+
# ks2 proc
29+
def get_ks2(raw, dsets, one):
30+
kwm = next(dset for dset in dsets if dset['dataset_type'] == 'kilosort.whitening_matrix')
31+
kwm = np.load(one.download_dataset(kwm))
32+
channels = [dset for dset in dsets if dset['dataset_type'].startswith('channels')]
33+
malf_path = next(iter(one.download_datasets(channels))).parent
34+
channels = alf.io.load_object(malf_path, 'channels')
35+
_car = raw[channels['rawInd'], :] - np.mean(raw[channels.rawInd, :], axis=0)
36+
sos = scipy.signal.butter(3, 300 / 30000 / 2, btype='highpass', output='sos')
37+
ks2 = np.zeros_like(raw)
38+
ks2[channels['rawInd'], :] = scipy.signal.sosfiltfilt(sos, _car)
39+
std_carbutt = np.std(ks2)
40+
ks2[channels['rawInd'], :] = np.matmul(kwm, ks2[channels['rawInd'], :])
41+
ks2 = ks2 * std_carbutt / np.std(ks2)
42+
return ks2
43+
44+
45+
def get_spikes(dsets, one):
46+
dtypes_spikes = ['spikes.clusters', 'spikes.amps', 'spikes.times', 'clusters.channels', 'spikes.samples']
47+
dsets_spikes = [dset for dset in dsets if dset['dataset_type'] in dtypes_spikes]
48+
malf_path = next(iter(one.download_datasets(dsets_spikes))).parent
49+
channels = alf.io.load_object(malf_path, 'channels')
50+
clusters = alf.io.load_object(malf_path, 'clusters')
51+
spikes = alf.io.load_object(malf_path, 'spikes')
52+
return spikes, clusters, channels
53+
54+
55+
def stream(pid, t0, one=None, cache=True, dsets=None):
56+
tlen = 1
57+
assert one
58+
if cache:
59+
samples_folder = Path(one._par.CACHE_DIR).joinpath('cache', 'ap')
60+
sample_file_name = Path(f"{pid}_{str(int(t0)).zfill(5)}.meta")
61+
if dsets is None:
62+
dsets = one.alyx.rest('datasets', 'list', probe_insertion=pid)
63+
if cache and samples_folder.joinpath(sample_file_name).exists():
64+
print(f'loading {sample_file_name} from cache')
65+
sr = spikeglx.Reader(samples_folder.joinpath(sample_file_name).with_suffix('.bin'))
66+
return sr, dsets
67+
68+
dset_ch = next(dset for dset in dsets if dset['dataset_type'] == "ephysData.raw.ch" and '.ap.' in dset['name'])
69+
dset_meta = next(dset for dset in dsets if dset['dataset_type'] == "ephysData.raw.meta" and '.ap.' in dset['name'])
70+
dset_cbin = next(dset for dset in dsets if dset['dataset_type'] == "ephysData.raw.ap" and '.ap.' in dset['name'])
71+
72+
file_ch, file_meta = one.download_datasets([dset_ch, dset_meta])
73+
74+
first_chunk = int(t0 / CHUNK_DURATION_SECS)
75+
last_chunk = int((t0 + tlen) / CHUNK_DURATION_SECS) - 1
76+
77+
sr = one.download_raw_partial(
78+
url_cbin=dataset_record_to_url(dset_cbin)[0],
79+
url_ch=file_ch,
80+
first_chunk=first_chunk,
81+
last_chunk=last_chunk)
82+
83+
if cache:
84+
out_meta = samples_folder.joinpath(sample_file_name)
85+
shutil.copy(sr.file_meta_data, out_meta)
86+
with open(out_meta.with_suffix('.bin'), 'wb') as fp:
87+
sr._raw[:].tofile(fp)
88+
89+
return sr, dsets

viewspikes/datoviz.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from dataclasses import dataclass
2+
3+
import numpy as np
4+
5+
import datoviz as dviz
6+
7+
# -------------------------------------------------------------------------------------------------
8+
# Raster viewer
9+
# -------------------------------------------------------------------------------------------------
10+
11+
12+
class RasterView:
13+
def __init__(self):
14+
self.canvas = dviz.canvas(show_fps=True)
15+
self.panel = self.canvas.panel(controller='axes')
16+
self.visual = self.panel.visual('point')
17+
self.pvars = {'ms': 2., 'alpha': .03}
18+
self.gui = self.canvas.gui('XY')
19+
self.gui.control("label", "Coords", value="(0, 0)")
20+
21+
def set_spikes(self, spikes):
22+
pos = np.c_[spikes.times, spikes.depths, np.zeros_like(spikes.times)]
23+
color = dviz.colormap(20 * np.log10(spikes.amps), cmap='cividis', alpha=self.pvars['alpha'])
24+
self.visual.data('pos', pos)
25+
self.visual.data('color', color)
26+
self.visual.data('ms', np.array([self.pvars['ms']]))
27+
28+
29+
class RasterController:
30+
_time_select_cb = None
31+
32+
def __init__(self, model, view):
33+
self.m = model
34+
self.v = view
35+
self.v.canvas.connect(self.on_mouse_move)
36+
self.v.canvas.connect(self.on_key_press)
37+
self.redraw()
38+
39+
def redraw(self):
40+
print('redraw', self.v.pvars)
41+
self.v.set_spikes(self.m.spikes)
42+
43+
def on_mouse_move(self, x, y, modifiers=()):
44+
p = self.v.canvas.panel_at(x, y)
45+
if not p:
46+
return
47+
# Then, we transform into the data coordinate system
48+
# Supported coordinate systems:
49+
# target_cds='data' / 'scene' / 'vulkan' / 'framebuffer' / 'window'
50+
xd, yd = p.pick(x, y)
51+
self.v.gui.set_value("Coords", f"({xd:0.2f}, {yd:0.2f})")
52+
53+
def on_key_press(self, key, modifiers=()):
54+
print(key, modifiers)
55+
if key == 'a' and modifiers == ('control',):
56+
self.v.pvars['alpha'] = np.minimum(self.v.pvars['alpha'] + 0.1, 1.)
57+
elif key == 'z' and modifiers == ('control',):
58+
self.v.pvars['alpha'] = np.maximum(self.v.pvars['alpha'] - 0.1, 0.)
59+
elif key == 'page_up':
60+
self.v.pvars['ms'] = np.minimum(self.v.pvars['ms'] * 1.1, 20)
61+
elif key == 'page_down':
62+
self.v.pvars['ms'] = np.maximum(self.v.pvars['ms'] / 1.1, 1)
63+
else:
64+
return
65+
self.redraw()
66+
67+
68+
@dataclass
69+
class RasterModel:
70+
spikes: dict
71+
72+
73+
def raster(spikes):
74+
rm = RasterController(RasterModel(spikes), RasterView())
75+
dviz.run()

viewspikes/examples.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import numpy as np
2+
import scipy
3+
from easyqc.gui import viewseis
4+
5+
from oneibl.one import ONE
6+
from ibllib.ephys import neuropixel
7+
from ibllib.dsp import voltage
8+
9+
from iblapps.viewspikes.data import stream, get_ks2, get_spikes
10+
from iblapps.viewspikes.plots import plot_insertion, show_psd, overlay_spikes
11+
12+
one = ONE()
13+
14+
## Example 1: Stream one second of ephys data
15+
pid, t0 = ('8413c5c6-b42b-4ec6-b751-881a54413628', 610)
16+
sr, dsets = stream(pid, t0=t0, one=one, cache=True)
17+
18+
## Example 2: Plot Insertion for a given PID (todo: use Needles 2 for interactive)
19+
plot_insertion(pid, one=one)
20+
21+
## Example 3: High-pass the data and show the PSD
22+
raw = sr[:, :-1].T
23+
show_psd(raw, sr.fs)
24+
25+
## Example 4: Display the raw / pre-proc and KS2 parts
26+
h = neuropixel.trace_header()
27+
sos = scipy.signal.butter(3, 300 / sr.fs / 2, btype='highpass', output='sos')
28+
butt = scipy.signal.sosfiltfilt(sos, raw)
29+
fk_kwargs ={'dx': 1, 'vbounds': [0, 1e6], 'ntr_pad': 160, 'ntr_tap': 0, 'lagc': .01, 'btype': 'lowpass'}
30+
destripe = voltage.destripe(raw, fs=sr.fs, fk_kwargs=fk_kwargs, tr_sel=np.arange(raw.shape[0]))
31+
ks2 = get_ks2(raw, dsets, one)
32+
eqc_butt = viewseis(butt.T, si=1 / sr.fs, h=h, t0=t0, title='butt', taxis=0)
33+
eqc_dest = viewseis(destripe.T, si=1 / sr.fs, h=h, t0=t0, title='destr', taxis=0)
34+
eqc_ks2 = viewseis(ks2.T, si=1 / sr.fs, h=h, t0=t0, title='ks2', taxis=0)
35+
36+
# Example 5: overlay the spikes on the existing easyqc instances
37+
spikes, clusters, channels = get_spikes(dsets, one)
38+
overlay_spikes(eqc_butt, spikes, clusters, channels)
39+
overlay_spikes(eqc_dest, spikes, clusters, channels)
40+
overlay_spikes(eqc_ks2, spikes, clusters, channels)
41+
42+
# hhh = {k: np.tile(h[k], 3) for k in h}
43+
# eqc_concat = viewseis(np.r_[butt, destripe, ks2], si=1 / sr.fs, h=hhh, t0=t0, title='concat')
44+
# overlay_spikes(eqc_concat, spikes, clusters, channels)

viewspikes/main.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from pathlib import Path
2+
3+
import scipy.signal
4+
import numpy as np
5+
6+
from ibllib.io import spikeglx
7+
from ibllib.dsp import voltage
8+
from ibllib.ephys import neuropixel
9+
from oneibl.one import ONE
10+
from easyqc.gui import viewseis
11+
12+
from viewspikes.plots import plot_insertion, show_psd, overlay_spikes
13+
from viewspikes.data import stream, get_spikes, get_ks2
14+
15+
folder_samples = Path('/datadisk/Data/spike_sorting/short_samples')
16+
files_samples = list(folder_samples.rglob('*.bin'))
17+
18+
one = ONE()
19+
SIDE_BY_SIDE = False
20+
#
21+
# pins = one.alyx.rest('insertions', 'list', django=('json__extended_qc__alignment_count__gt,0'))
22+
# pid, t0 = ('3e7618b8-34ca-4e48-ba3a-0e0f88a43131', 1002) # SWC_054_2020-10-10_probe01__ - sync w/ spikes !!!
23+
# pid, t0 = ('04c9890f-2276-4c20-854f-305ff5c9b6cf', 1002.) # SWC_054_2020-10-10_probe00__04c9890f-2276-4c20-854f-305ff5c9b6cf - sync w/ spikes !!!
24+
# pid, t0 = ('0925fb1b-cf83-4f55-bfb7-aa52f993a404', 500.) # DY_013_2020-03-06_probe00__0925fb1b-cf83-4f55-bfb7-aa52f993a404
25+
# pid, t0 = ('0ece5c6a-7d1e-4365-893d-ac1cc04f1d7b', 750.) # CSHL045_2020-02-27_probe01__0ece5c6a-7d1e-4365-893d-ac1cc04f1d7b
26+
# pid, t0 = ('0ece5c6a-7d1e-4365-893d-ac1cc04f1d7b', 3000.) # CSHL045_2020-02-27_probe01__0ece5c6a-7d1e-4365-893d-ac1cc04f1d7b
27+
pid, t0 = ('10ef1dcd-093c-4839-8f38-90a25edefb49', 2400.)
28+
# pid, t0 = ('1a6a17cc-ba8c-4d79-bf20-cc897c9500dc', 5000)
29+
# pid, t0 = ('2dd99c91-292f-44e3-bbf2-8cfa56015106', 2500) # NYU-23_2020-10-14_probe01__2dd99c91-292f-44e3-bbf2-8cfa56015106
30+
# pid, t0 = ('2dd99c91-292f-44e3-bbf2-8cfa56015106', 6000) # NYU-23_2020-10-14_probe01__2dd99c91-292f-44e3-bbf2-8cfa56015106
31+
# pid, t0 = ('30dfb8c6-9202-43fd-a92d-19fe68602b6f', 2400.) # ibl_witten_27_2021-01-16_probe00__30dfb8c6-9202-43fd-a92d-19fe68602b6f
32+
# pid, t0 = ('31dd223c-0c7c-48b5-a513-41feb4000133', 3000.) # really good one : striping on not all channels
33+
# pid, t0 = ('39b433d0-ec60-460f-8002-a393d81620a4', 2700.) # ZFM-01577_2020-10-27_probe01 needs FDNAT
34+
# pid, t0 = ('47da98a8-f282-4830-92c2-af0e1d4f00e2', 2700.)
35+
36+
# 67 frequency spike
37+
# 458 /datadisk/Data/spike_sorting/short_samples/b45c8f3f-6361-41df-9bc1-9df98b3d30e6_01210.bin ERROR dans le chargement de la whitening matrix
38+
# 433 /datadisk/Data/spike_sorting/short_samples/8d59da25-3a9c-44be-8b1a-e27cdd39ca34_04210.bin Cortex complètement silencieux.
39+
# 531 /datadisk/Data/spike_sorting/short_samples/47be9ae4-290f-46ab-b047-952bc3a1a509_00010.bin Sympa pour le spike sorting, un bon example de trace pourrie à enlever avec FDNAT / Cadzow. Il y a du striping à la fin mais pas de souci pour KS2 ou pour le FK.
40+
# 618 5b9ce60c-dcc9-4789-b2ff-29d873829fa5_03610.bin: gros cabossage plat laissé par le FK !! Tester un filtre K tout bête # spikes tous petits en comparaison. Le spike sorting a l'air décalé
41+
# 681 /datadisk/Data/spike_sorting/short_samples/eab93ab0-26e3-4bd9-9c53-9f81c35172f4_02410.bin !! Spikes décalés. Superbe example de layering dans le cerveau avec 3 niveaux très clairement définis
42+
# 739 /datadisk/Data/spike_sorting/short_samples/f03b61b4-6b13-479d-940f-d1608eb275cc_04210.bin: Autre example de layering ou les charactéristiques spectrales / spatiales sont très différentes. Spikes alignés
43+
# 830 /datadisk/Data/spike_sorting/short_samples/b02c0ce6-2436-4fc0-9ea0-e7083a387d7e_03010.bin, très mauvaise qualité - spikes sont décalés ?!?
44+
45+
46+
47+
file_ind = np.random.randint(len(files_samples))
48+
file_ind = 739 # very good quality spike sorting
49+
print(file_ind, files_samples[file_ind])
50+
51+
pid, t0 = ('47da98a8-f282-4830-92c2-af0e1d4f00e2', 1425.)
52+
53+
pid = files_samples[file_ind]
54+
# pid, t0 = ("01c6065e-eb3c-49ba-9c25-c1f17b18d529", 500)
55+
if isinstance(pid, Path):
56+
file_sample = pid
57+
pid, t0 = file_sample.stem.split('_')
58+
t0 = float(t0)
59+
sr = spikeglx.Reader(file_sample)
60+
dsets = one.alyx.rest('datasets', 'list', probe_insertion=pid)
61+
else:
62+
sr, dsets = stream(pid, t0, one=one, samples_folder=folder_samples)
63+
64+
#
65+
plot_insertion(pid, one)
66+
67+
68+
h = neuropixel.trace_header()
69+
raw = sr[:, :-1].T
70+
71+
sos = scipy.signal.butter(3, 300 / sr.fs / 2, btype='highpass', output='sos')
72+
butt = scipy.signal.sosfiltfilt(sos, raw)
73+
# show_psd(butt, sr.fs)
74+
75+
fk_kwargs ={'dx': 1, 'vbounds': [0, 1e6], 'ntr_pad': 160, 'ntr_tap': 0, 'lagc': .01, 'btype': 'lowpass'}
76+
destripe = voltage.destripe(raw, fs=sr.fs, fk_kwargs=fk_kwargs, tr_sel=np.arange(raw.shape[0]))
77+
ks2 = get_ks2(raw, dsets, one)
78+
79+
# get the spikes corresponding to current chunk, here needs to go through samples for sync reasons
80+
spikes, clusters, channels = get_spikes(dsets, one)
81+
82+
if SIDE_BY_SIDE:
83+
hhh = {k: np.tile(h[k], 3) for k in h}
84+
eqc_concat = viewseis(np.r_[butt, destripe, ks2], si=1 / sr.fs, h=hhh, t0=t0, title='concat')
85+
overlay_spikes(eqc_concat, spikes, clusters, channels)
86+
else:
87+
eqc_butt = viewseis(butt.T, si=1 / sr.fs, h=h, t0=t0, title='butt', taxis=0)
88+
eqc_dest = viewseis(destripe.T, si=1 / sr.fs, h=h, t0=t0, title='destr', taxis=0)
89+
eqc_ks2 = viewseis(ks2.T, si=1 / sr.fs, h=h, t0=t0, title='ks2', taxis=0)
90+
overlay_spikes(eqc_butt, spikes, clusters, channels)
91+
overlay_spikes(eqc_dest, spikes, clusters, channels)
92+
overlay_spikes(eqc_ks2, spikes, clusters, channels)

0 commit comments

Comments
 (0)