
Commit 0cfb83a

chris-langfield, oliche and GaelleChapuis authored
Spike sorting rerun (#755)
* add wf extraction to SpikeSorting task * ibl-neuropixel version * add the sync dataset in the spike sorting loader probe info * add keys * add to metrics dict * bitwise qc label * reverse the bitwise qc labels so 0 is always passin * adjust wf extraction call * rename bitwise qc key * fix test metrics * check for pyks import * add the cell qc computation at the spike sorting stage * PopeyeDataHandler * add wf extraction to SpikeSorting task * ibl-neuropixel version * add the sync dataset in the spike sorting loader probe info * add keys * add to metrics dict * bitwise qc label * reverse the bitwise qc labels so 0 is always passin * adjust wf extraction call * rename bitwise qc key * fix test metrics * check for pyks import * add the cell qc computation at the spike sorting stage * PopeyeDataHandler * update spike sorting plots * pass when symlink exists * revert patcher path changes * patcher patch * sdsc spikesorting registration * bugfix: SpikeSortingLoader.raw_electrophysiology regex to match cbin files on SDSC * random tempdir for pykilosort * fix ss reg task * final fix RegisterSpikeSorting * add low memory option to SS task * Popeye patcher: allows overriding SDSC_PATCH_PATH with env variable * popeye data handler Path bugfix * revision in sdsc registration * update revision in constructor * SDSC DataHandler get patch path from env * Revert "SDSC DataHandler get patch path from env" This reverts commit eeea95c. * pin slidingRP requirement - we can unpin when merging with newer ibllib * slidingRP * fix slidingRP bug * spike sorting if pykilosort is available, do not run subprocess * flake * revert regex for pyks find version * task data handler allows empty input dataset list * changes to spike sorting loader * spike sorting loader gets a good_units parameters --------- Co-authored-by: owinter <[email protected]> Co-authored-by: Gaelle <[email protected]> Co-authored-by: chris-langfield <[email protected]>
1 parent ac3f36e commit 0cfb83a

File tree

10 files changed (+397, -205 lines)


brainbox/io/one.py

Lines changed: 100 additions & 20 deletions
@@ -2,10 +2,10 @@
 from dataclasses import dataclass, field
 import gc
 import logging
+import re
 import os
 from pathlib import Path
 
-
 import numpy as np
 import pandas as pd
 from scipy.interpolate import interp1d
@@ -19,13 +19,14 @@
 from neuropixel import TIP_SIZE_UM, trace_header
 import spikeglx
 
+import ibldsp.voltage
 from iblutil.util import Bunch
-from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
 from iblatlas.atlas import AllenAtlas, BrainRegions
 from iblatlas import atlas
+from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
 from ibllib.pipes import histology
 from ibllib.pipes.ephys_alignment import EphysAlignment
-from ibllib.plots import vertical_lines
+from ibllib.plots import vertical_lines, Density
 
 import brainbox.plot
 from brainbox.io.spikeglx import Streamer
@@ -916,16 +917,18 @@ def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_
             if missing == 'raise':
                 raise e
 
-    def download_spike_sorting(self, **kwargs):
+    def download_spike_sorting(self, objects=None, **kwargs):
         """
         Downloads spikes, clusters and channels
         :param spike_sorter: (defaults to 'pykilosort')
         :param dataset_types: list of extra dataset types
+        :param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels']
         :return:
         """
-        for obj in ['spikes', 'clusters', 'channels']:
+        objects = ['spikes', 'clusters', 'channels'] if objects is None else objects
+        for obj in objects:
             self.download_spike_sorting_object(obj=obj, **kwargs)
-        self.spike_sorting_path = self.files['spikes'][0].parent
+        self.spike_sorting_path = self.files['clusters'][0].parent
 
     def download_raw_electrophysiology(self, band='ap'):
         """
@@ -963,7 +966,7 @@ def raw_electrophysiology(self, stream=True, band='ap', **kwargs):
             return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs)
         else:
             raw_data_files = self.download_raw_electrophysiology(band=band)
-            cbin_file = next(filter(lambda f: f.name.endswith(f'.{band}.cbin'), raw_data_files), None)
+            cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None)
             if cbin_file is not None:
                 return spikeglx.Reader(cbin_file)
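The switch from `endswith` to a regex accommodates SDSC file naming, where a UUID sits between the band suffix and the `cbin` extension. A minimal sketch of the difference; both file names below are made up for illustration:

```python
import re

band = "ap"
names = [
    "_spikeglx_ephysData_g0_t0.imec0.ap.cbin",  # local naming convention
    "_spikeglx_ephysData_g0_t0.imec0.ap.1e2c4a00-588f-4b01-b7a6-78653fe5a3f0.cbin",  # SDSC naming with UUID
]
# the old endswith test misses the UUID variant; the regex matches both
print([n.endswith(f".{band}.cbin") for n in names])                # [True, False]
print([bool(re.match(rf".*\.{band}\..*cbin", n)) for n in names])  # [True, True]
```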
@@ -999,7 +1002,7 @@ def load_channels(self, **kwargs):
             self.histology = 'alf'
         return channels
 
-    def load_spike_sorting(self, spike_sorter='pykilosort', **kwargs):
+    def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs):
         """
         Loads spikes, clusters and channels
 
@@ -1013,20 +1016,44 @@ def load_spike_sorting(self, spike_sorter='pykilosort', **kwargs):
         - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data
 
         :param spike_sorter: (defaults to 'pykilosort')
-        :param dataset_types: list of extra dataset types
+        :param revision: for example "2024-05-06", (defaults to None):
+        :param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one
+        :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates']
+        :param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table
+        :param kwargs: additional arguments to be passed to one.api.One.load_object
        :return:
        """
         if len(self.collections) == 0:
             return {}, {}, {}
         self.files = {}
         self.spike_sorter = spike_sorter
-        self.download_spike_sorting(spike_sorter=spike_sorter, **kwargs)
-        channels = self.load_channels(spike_sorter=spike_sorter, **kwargs)
+        self.revision = revision
+        objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None
+        self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs)
+        channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs)
         clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards)
-        spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)
-
+        if good_units:
+            spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards)
+        else:
+            spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)
+        if enforce_version:
+            self._assert_version_consistency()
         return spikes, clusters, channels
 
+    def _assert_version_consistency(self):
+        """
+        Makes sure the state of the spike sorting object matches the files downloaded
+        :return: None
+        """
+        for k in ['spikes', 'clusters', 'channels', 'passingSpikes']:
+            for fn in self.files.get(k, []):
+                if self.spike_sorter:
+                    assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \
+                        f"You required strict version {self.spike_sorter}, {fn} does not match"
+                if self.revision:
+                    assert fn.relative_to(self.session_path).parts[3] == f"#{self.revision}#", \
+                        f"You required strict revision {self.revision}, {fn} does not match"
+
     @staticmethod
     def compute_metrics(spikes, clusters=None):
         nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size
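A usage sketch for the new loader options; the probe insertion UUID below is a placeholder, and the revision must exist on the server for the consistency assertion to pass:

```python
from one.api import ONE
from brainbox.io.one import SpikeSortingLoader

one = ONE()
ssl = SpikeSortingLoader(pid='decc8d40-cf74-4263-ae9d-a0cc9b2769a7', one=one)  # placeholder pid
# load only the good units from a pinned revision; with enforce_version=True (the default),
# _assert_version_consistency checks the downloaded paths against sorter and revision
spikes, clusters, channels = ssl.load_spike_sorting(
    spike_sorter='pykilosort', revision='2024-05-06', good_units=True)
```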
@@ -1079,6 +1106,8 @@ def _get_probe_info(self):
         if self._sync is None:
             timestamps = self.one.load_dataset(
                 self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}')
+            _ = self.one.load_dataset(  # this is not used here but we want to trigger the download for potential tasks
+                self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}')
             try:
                 ap_meta = spikeglx.read_meta_data(self.one.load_dataset(
                     self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}'))
@@ -1116,7 +1145,13 @@ def samples2times(self, values, direction='forward'):
     def pid2ref(self):
         return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}"
 
-    def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, **kwargs):
+    def _default_plot_title(self, spikes):
+        title = f"{self.pid2ref}, {self.pid} \n" \
+                f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters"
+        return title
+
+    def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None,
+               drift=None, title=None, **kwargs):
         """
         :param spikes: spikes dictionary or Bunch
         :param channels: channels dictionary or Bunch.
@@ -1138,9 +1173,9 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_
         # set default raster plot parameters
         kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5}
         brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs)
-        title_str = f"{self.pid2ref}, {self.pid} \n" \
-                    f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters"
-        axs[0, 0].title.set_text(title_str)
+        if title is None:
+            title = self._default_plot_title(spikes)
+        axs[0, 0].title.set_text(title)
         for k, ts in time_series.items():
             vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0])
         if 'atlas_id' in channels:
@@ -1150,10 +1185,55 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_
         axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1])
         fig.tight_layout()
 
-        self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore')
-        if 'drift' in self.files:
-            drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
+        if drift is None:
+            self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore')
+            if 'drift' in self.files:
+                drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
+        if isinstance(drift, dict):
             axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5)
+            axs[0, 0].set(ylim=[-15, 15])
+
+        if save_dir is not None:
+            png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
+            fig.savefig(png_file)
+            plt.close(fig)
+            gc.collect()
+        else:
+            return fig, axs
+
+    def plot_rawdata_snippet(self, sr, spikes, clusters, t0,
+                             channels=None,
+                             br: BrainRegions = None,
+                             save_dir=None,
+                             label='raster',
+                             gain=-93,
+                             title=None):
+
+        # compute the raw data offset and destripe, we take 400ms around t0
+        first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs))
+        raw = sr[first_sample:last_sample, :-sr.nsync].T
+        channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True
+        destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels)
+        # filter out the spikes according to good/bad clusters and to the time slice
+        spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample]))
+        ss = spikes['samples'][spike_sel]
+        sc = clusters['channels'][spikes['clusters'][spike_sel]]
+        sok = clusters['label'][spikes['clusters'][spike_sel]] == 1
+        if title is None:
+            title = self._default_plot_title(spikes)
+        # display the raw data snippet with spikes overlaid
+        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col')
+        Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s')
+        axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5)
+        axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5)
+        axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035])
+        # adds the channel locations if available
+        if (channels is not None) and ('atlas_id' in channels):
+            br = br or BrainRegions()
+            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
+                               brain_regions=br, display=True, ax=axs[1], title=self.histology)
+            axs[1].get_yaxis().set_visible(False)
+        fig.tight_layout()
 
         if save_dir is not None:
             png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)

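A minimal usage sketch for the new `plot_rawdata_snippet` method, assuming the loader from the example above; `t0` and the output path are arbitrary, and the spikes table needs `samples` while the clusters table needs `channels` and `label` (availability depends on the sorting output):

```python
from pathlib import Path

sr = ssl.raw_electrophysiology(band='ap', stream=True)
spikes, clusters, channels = ssl.load_spike_sorting(dataset_types=['spikes.samples'])
# plot 70 ms of destriped raw data around t0 with good (green) and bad (red) spikes overlaid
ssl.plot_rawdata_snippet(sr, spikes, clusters, t0=600.0, channels=channels,
                         save_dir=Path('/tmp/snippet.png'))  # hypothetical output path
```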
brainbox/metrics/single_units.py

Lines changed: 33 additions & 17 deletions
@@ -42,11 +42,12 @@
     'missed_spikes_est': dict(spks_per_bin=10, sigma=4, min_num_bins=50),
     'acceptable_contamination': 0.1,
     'bin_size': 0.25,
-    'med_amp_thresh_uv': 50,
+    'med_amp_thresh_uv': 50,  # units below this threshold are considered noise
     'min_isi': 0.0001,
     'presence_window': 10,
     'refractory_period': 0.0015,
     'RPslide_thresh': 0.1,
+    'RPmax_confidence': 90,  # a unit needs to pass with at least this confidence percentage (0 - 100)
 }
 
 
@@ -942,7 +943,11 @@ def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
         'presence_ratio',
         'presence_ratio_std',
         'slidingRP_viol',
-        'spike_count'
+        'spike_count',
+        'slidingRP_viol_forced',
+        'max_confidence',
+        'min_contamination',
+        'n_spikes_below2'
     ]
     if tbounds:
         ispi = between_sorted(spike_times, tbounds)
@@ -982,6 +987,10 @@ def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
     srp = metrics.slidingRP_all(spikeTimes=spike_times, spikeClusters=spike_clusters,
                                 sampleRate=30000, binSizeCorr=1 / 30000)
     r.slidingRP_viol[ir] = srp['value']
+    r.slidingRP_viol_forced[ir] = srp['value_forced']
+    r.max_confidence[ir] = srp['max_confidence']
+    r.min_contamination[ir] = srp['min_contamination']
+    r.n_spikes_below2 = srp['n_spikes_below2']
 
     # loop over each cluster to compute the rest of the metrics
     for ic in np.arange(nclust):
@@ -1000,29 +1009,36 @@ def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
         r.missed_spikes_est[ic], _, _ = missed_spikes_est(amps, **params['missed_spikes_est'])
         # wonder if there is a need to low-cut this
         r.drift[ic] = np.sum(np.abs(np.diff(depths))) / (tmax - tmin) * 3600
-
-    r.label = compute_labels(r)
+    r.label, r.bitwise_fail = compute_labels(r, return_bitwise=True)
     return r
 
 
-def compute_labels(r, params=METRICS_PARAMS, return_details=False):
+def compute_labels(r, params=METRICS_PARAMS, return_bitwise=False):
     """
-    From a dataframe or a dictionary of unit metrics, compute a lablel
+    From a dataframe or a dictionary of unit metrics, compute a label
     :param r: dictionary or pandas dataframe containing unit qcs
-    :param return_details: False (returns a full dictionary of metrics)
+    :param return_bitwise: True (returns a full dictionary of metrics)
     :return: vector of proportion of qcs passed between 0 and 1, where 1 denotes an all pass
     """
-    # right now the score is a value between 0 and 1 denoting the proportion of passing qcs
-    # we could eventually do a bitwise qc
+    # right now the score is a value between 0 and 1 denoting the proportion of passing qcs,
+    # where 1 means passing and 0 means failing
     labels = np.c_[
-        r.slidingRP_viol,
+        r['max_confidence'] >= params['RPmax_confidence'],  # this is the least significant bit
         r.noise_cutoff < params['noise_cutoff']['nc_threshold'],
         r.amp_median > params['med_amp_thresh_uv'] / 1e6,
+        # add a new metric here on higher significant bits
     ]
-    if not return_details:
-        return np.mean(labels, axis=1)
-    column_names = ['slidingRP_viol', 'noise_cutoff', 'amp_median']
-    qcdict = {}
-    for c in np.arange(labels.shape[1]):
-        qcdict[column_names[c]] = labels[:, c]
-    return np.mean(labels, axis=1), qcdict
+    # The first column takes binary values 001 or 000 to represent fail or pass,
+    # the second, 010 or 000, the third, 100 or 000 etc.
+    # The bitwise or "sum" produces 111 if all metrics fail, or 000 if all metrics pass
+    # All other permutations are also captured, i.e. 110 == 000 || 010 || 100 means
+    # the second and third metrics failed and the first metric was a pass
+    score = np.mean(labels, axis=1)
+    if return_bitwise:
+        # note the cast to uint8 casts nan to 0
+        # a nan implies no metrics was computed which we mark as a failure here
+        n_criteria = labels.shape[1]
+        bitwise = np.bitwise_or.reduce(2 ** np.arange(n_criteria) * (~ labels.astype(bool)).astype(np.uint8), axis=1)
+        return score, bitwise.astype(np.uint8)
+    else:
+        return score
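A sketch of reading the new `bitwise_fail` values back into per-criterion failures, assuming the bit order of the `labels` matrix above (least significant bit first):

```python
import numpy as np

criteria = ['max_confidence', 'noise_cutoff', 'amp_median']  # bits 0, 1, 2
bitwise_fail = np.array([0b000, 0b010, 0b101], dtype=np.uint8)
for value in map(int, bitwise_fail):
    failed = [name for i, name in enumerate(criteria) if value & (1 << i)]
    print(f"{value:03b} -> failed: {failed if failed else 'none'}")
# 000 -> failed: none
# 010 -> failed: ['noise_cutoff']
# 101 -> failed: ['max_confidence', 'amp_median']
```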

brainbox/plot.py

Lines changed: 1 addition & 1 deletion
@@ -673,7 +673,7 @@ def driftmap(ts, feat, ax=None, plot_style='bincount',
     else:
         # compute raster map as a function of site depth
         R, times, depths = bincount2D(
-            ts[iok], feat[iok], t_bin, d_bin, weights=weights)
+            ts[iok], feat[iok], t_bin, d_bin, weights=weights[iok] if weights is not None else None)
     # plot raster map
     ax.imshow(R, aspect='auto', cmap='binary', vmin=0, vmax=vmax or np.std(R) * 4,
               extent=np.r_[times[[0, -1]], depths[[0, -1]]], origin='lower', **kwargs)
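The one-line fix keeps `weights` aligned with the NaN-masked time and depth vectors; a minimal illustration of the length mismatch it prevents:

```python
import numpy as np

ts = np.array([0.1, 0.2, np.nan, 0.4])
iok = ~np.isnan(ts)
weights = np.ones_like(ts)
print(ts[iok].size, weights.size)       # 3 4 -- mismatched lengths before the fix
print(ts[iok].size, weights[iok].size)  # 3 3 -- consistent after the fix
```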

brainbox/tests/test_metrics.py

Lines changed: 3 additions & 1 deletion
@@ -62,7 +62,9 @@ def _assertions(dfm, idf, target_cid):
         assert np.allclose(dfm['drift'][idf], np.array(cid) * 100 * 4 * 3.6, rtol=1.1)
         assert np.allclose(dfm['firing_rate'][idf], frs, rtol=1.1)
         assert np.allclose(dfm['cluster_id'], target_cid)
-
+        # test expected bitwise qc values:
+        expected_labels = 1 - np.sum(np.unpackbits(dfm['bitwise_fail']).reshape(-1, 8), axis=1) / 3
+        assert np.allclose(dfm['label'], expected_labels)
     # check with missing clusters
     dfm = quick_unit_metrics(c, t, a, d, cluster_ids=np.arange(5), tbounds=[100, 900])
     idf, _ = ismember(np.arange(5), cid)
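The expected value follows from the bitwise encoding: with three QC criteria, each set fail bit removes 1/3 from the label. A quick numeric check:

```python
import numpy as np

bitwise_fail = np.array([0b000, 0b001, 0b011, 0b111], dtype=np.uint8)
n_fail = np.sum(np.unpackbits(bitwise_fail).reshape(-1, 8), axis=1)
print(1 - n_fail / 3)  # [1., 0.667, 0.333, 0.] (rounded)
```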

ibllib/oneibl/data_handlers.py

Lines changed: 28 additions & 3 deletions
@@ -51,6 +51,8 @@ def getData(self, one=None):
         for file in self.signature['input_files']:
             dfs.append(filter_datasets(session_datasets, filename=file[0], collection=file[1],
                                        wildcards=True, assert_unique=False))
+        if len(dfs) == 0:
+            return pd.DataFrame()
         df = pd.concat(dfs)
 
         # Some cases the eid is stored in the index. If so we drop this level
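The early return guards against `pd.concat` raising on an empty list, which is what would happen for tasks whose input dataset list is empty:

```python
import pandas as pd

try:
    pd.concat([])
except ValueError as e:
    print(e)  # "No objects to concatenate"
```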
@@ -413,23 +415,29 @@ class SDSCDataHandler(DataHandler):
     :param signature: input and output file signatures
     :param one: ONE instance
     """
+
     def __init__(self, task, session_path, signatures, one=None):
         super().__init__(session_path, signatures, one=one)
         self.task = task
+        self.SDSC_PATCH_PATH = SDSC_PATCH_PATH
+        self.SDSC_ROOT_PATH = SDSC_ROOT_PATH
 
     def setUp(self):
         """Function to create symlinks to necessary data to run tasks."""
         df = super().getData()
 
-        SDSC_TMP = Path(SDSC_PATCH_PATH.joinpath(self.task.__class__.__name__))
+        SDSC_TMP = Path(self.SDSC_PATCH_PATH.joinpath(self.task.__class__.__name__))
         for i, d in df.iterrows():
             file_path = Path(d['session_path']).joinpath(d['rel_path'])
             uuid = i
             file_uuid = add_uuid_string(file_path, uuid)
             file_link = SDSC_TMP.joinpath(file_path)
             file_link.parent.mkdir(exist_ok=True, parents=True)
-            file_link.symlink_to(
-                Path(SDSC_ROOT_PATH.joinpath(file_uuid)))
+            try:
+                file_link.symlink_to(
+                    Path(self.SDSC_ROOT_PATH.joinpath(file_uuid)))
+            except FileExistsError:
+                pass
 
         self.task.session_path = SDSC_TMP.joinpath(d['session_path'])
 
448456
"""Function to clean up symlinks created to run task."""
449457
assert SDSC_PATCH_PATH.parts[0:4] == self.task.session_path.parts[0:4]
450458
shutil.rmtree(self.task.session_path)
459+
460+
461+
class PopeyeDataHandler(SDSCDataHandler):
462+
463+
def __init__(self, task, session_path, signatures, one=None):
464+
super().__init__(task, session_path, signatures, one=one)
465+
self.SDSC_PATCH_PATH = Path(os.getenv('SDSC_PATCH_PATH', "/mnt/sdceph/users/ibl/data/quarantine/tasks/"))
466+
self.SDSC_ROOT_PATH = Path("/mnt/sdceph/users/ibl/data")
467+
468+
def uploadData(self, outputs, version, **kwargs):
469+
raise NotImplementedError(
470+
"Cannot register data from Popeye. Login as Datauser and use the RegisterSpikeSortingSDSC task."
471+
)
472+
473+
def cleanUp(self):
474+
"""Symlinks are preserved until registration."""
475+
pass
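A sketch of the environment override `PopeyeDataHandler` reads at construction time; the override path below is hypothetical:

```python
import os
from pathlib import Path

os.environ['SDSC_PATCH_PATH'] = '/mnt/sdceph/users/ibl/data/quarantine/my_tasks'  # hypothetical
patch_path = Path(os.getenv('SDSC_PATCH_PATH', '/mnt/sdceph/users/ibl/data/quarantine/tasks/'))
print(patch_path)  # /mnt/sdceph/users/ibl/data/quarantine/my_tasks
```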
