Commit 8a33b2d

Merge branch 'rawqc' into develop
2 parents: b678717 + 00093a6

File tree: 11 files changed, +272 −91 lines


brainbox/io/one.py

Lines changed: 73 additions & 30 deletions
@@ -108,7 +108,8 @@ def _channels_alf2bunch(channels, brain_regions=None):
     return channels_


-def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None):
+def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None,
+                        brain_regions=None):
     """
     Generic function to load spike sorting according to ONE search words
     Will try to load one spike sorting for any probe present for the eid matching the collection
@@ -121,6 +122,7 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch
     :param collection: collection filter word - accepts wildcard - can be a combination of spike sorter and probe
     :param revision: revision to load
     :param return_channels: True
+    :param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
     :return:
     """
     one = one or ONE()
@@ -140,7 +142,8 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch
         clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters',
                                           attribute=cluster_attributes)

-    channels = _load_channels_locations_from_disk(eid, collection=collection, one=one, revision=revision)
+    channels = _load_channels_locations_from_disk(eid, collection=collection, one=one, revision=revision,
+                                                  brain_regions=brain_regions)

    if return_channels:
        return spikes, clusters, channels
@@ -179,31 +182,42 @@ def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=
             _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}")
             ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe)
             channels_aligned = one.load_object(eid, 'channels', collection=ac_collection)
-            # oftentimes the channel map for different spike sorters may be different so interpolate the alignment onto
-            nch = channels[probe]['localCoordinates'].shape[0]
-            # if there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field
-            # so we reconstruct from the Neuropixel map. This only happens for early pykilosort sorts
-            if 'localCoordinates' in channels_aligned.keys():
-                aligned_depths = channels_aligned['localCoordinates'][:, 1]
-            else:
-                assert channels_aligned['mlapdv'].shape[0] == 384
-                NEUROPIXEL_VERSION = 1
-                from ibllib.ephys.neuropixel import trace_header
-                aligned_depths = trace_header(version=NEUROPIXEL_VERSION)['y']
-            depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True)
-            depths, ind, iinv = np.unique(channels[probe]['localCoordinates'][:, 1], return_index=True, return_inverse=True)
-            channels[probe]['mlapdv'] = np.zeros((nch, 3))
-            for i in np.arange(3):
-                channels[probe]['mlapdv'][:, i] = np.interp(
-                    depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv]
-            # the brain locations have to be interpolated by nearest neighbour
-            fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest')
-            channels[probe]['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32)
+            channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe])
         # only have to reformat channels if we were able to load coordinates from disk
         channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions)
     return channels


+def channel_locations_interpolation(channels_aligned, channels):
+    """
+    Oftentimes the channel map for different spike sorters may be different, so interpolate the alignment onto
+    the channel map at hand. If there is no spike sorting in the base folder, the alignment doesn't have the
+    localCoordinates field, so we reconstruct it from the Neuropixel map. This only happens for early pykilosort sorts.
+    :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys
+     'mlapdv' and 'brainLocationIds_ccf_2017' - those are the guide for the interpolation
+    :param channels: Bunch or dictionary of aligned channels containing at least keys 'localCoordinates'
+    :return: Bunch or dictionary of channels with extra keys 'mlapdv' and 'brainLocationIds_ccf_2017'
+    """
+    nch = channels['localCoordinates'].shape[0]
+    if 'localCoordinates' in channels_aligned.keys():
+        aligned_depths = channels_aligned['localCoordinates'][:, 1]
+    else:
+        assert channels_aligned['mlapdv'].shape[0] == 384
+        NEUROPIXEL_VERSION = 1
+        from ibllib.ephys.neuropixel import trace_header
+        aligned_depths = trace_header(version=NEUROPIXEL_VERSION)['y']
+    depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True)
+    depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True)
+    channels['mlapdv'] = np.zeros((nch, 3))
+    for i in np.arange(3):
+        channels['mlapdv'][:, i] = np.interp(
+            depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv]
+    # the brain locations have to be interpolated by nearest neighbour
+    fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest')
+    channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32)
+    return channels
+
+
 def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False,
                                  brain_atlas=None):
     print('from traj')
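
Note: the extracted channel_locations_interpolation helper can now be exercised on its own. A minimal sketch, assuming the brainbox/ibllib environment; the 384-channel geometry and the random mlapdv/brainLocationIds values below are made up for illustration:

import numpy as np
from brainbox.io.one import channel_locations_interpolation

# hypothetical aligned channel map: 384 channels, 2 per row, 20 um depth pitch
depths = np.repeat(np.arange(192) * 20., 2)
channels_aligned = {
    'localCoordinates': np.c_[np.tile(np.array([16., 48.]), 192), depths],
    'mlapdv': np.random.uniform(-5e3, 5e3, (384, 3)),  # made-up coordinates
    'brainLocationIds_ccf_2017': np.random.randint(0, 2000, 384),  # made-up ids
}
# the sorter-specific channel map only needs 'localCoordinates'
channels = {'localCoordinates': channels_aligned['localCoordinates'].copy()}
channels = channel_locations_interpolation(channels_aligned, channels)
assert channels['mlapdv'].shape == (384, 3)
assert channels['brainLocationIds_ccf_2017'].dtype == np.int32
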
@@ -309,17 +323,43 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas
     return channels


-def load_spike_sorting_fast(eid, probe=None, spike_sorter=None, **kwargs):
+def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
+                            brain_regions=None, nested=True):
     """
-    Same as load_spike_sorting but with return_channels=True
+    From an eid, loads spikes and clusters for all probes
+    The following set of dataset types are loaded:
+        'clusters.channels',
+        'clusters.depths',
+        'clusters.metrics',
+        'spikes.clusters',
+        'spikes.times',
+        'probes.description'
+    :param eid: experiment UUID or pathlib.Path of the local session
+    :param one: an instance of OneAlyx
+    :param probe: name of probe to load in, if not given all probes for session will be loaded
+    :param dataset_types: additional spikes/clusters objects to add to the standard default list
+    :param spike_sorter: name of the spike sorting you want to load (None for default)
+    :param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
+    :param nested: if a single probe is required, do not output a dictionary with the probe name as key
+    :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
+    """
     collection = _collection_filter_from_args(probe, spike_sorter)
     _logger.debug(f"load spike sorting with collection filter {collection}")
-    return _load_spike_sorting(eid, collection=collection, return_channels=True, **kwargs)
+    kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types,
+                  brain_regions=brain_regions)
+    spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True)
+    clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
+    if nested is False:
+        k = list(spikes.keys())[0]
+        channels = channels[k]
+        clusters = clusters[k]
+        spikes = spikes[k]
+    return spikes, clusters, channels


-def load_spike_sorting(eid, one=None, probe=None, dataset_types=None,
-                       spike_sorter=None, revision=None, return_channels=False):
+def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
+                       brain_regions=None):
     """
     From an eid, loads spikes and clusters for all probes
     The following set of dataset types are loaded:
@@ -335,12 +375,15 @@ def load_spike_sorting(eid, one=None, probe=None, dataset_types=None,
     :param dataset_types: additional spikes/clusters objects to add to the standard default list
     :param spike_sorter: name of the spike sorting you want to load (None for default)
     :param return_channels: (bool) defaults to False otherwise tries to load channels from disk
-    :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
+    :param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
+    :return: spikes, clusters (dict of bunch, 1 bunch per probe)
     """
     collection = _collection_filter_from_args(probe, spike_sorter)
     _logger.debug(f"load spike sorting with collection filter {collection}")
-    return _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision,
-                               return_channels=return_channels, dataset_types=dataset_types)
+    spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision,
+                                           return_channels=False, dataset_types=dataset_types,
+                                           brain_regions=brain_regions)
+    return spikes, clusters


 def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None,
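
Note: a sketch of how the reworked loaders might be called; the import paths and the placeholder eid are assumptions, not taken from this diff:

from one.api import ONE
from ibllib.atlas import BrainRegions
from brainbox.io.one import load_spike_sorting, load_spike_sorting_fast

one = ONE()
br = BrainRegions()
eid = '00000000-0000-0000-0000-000000000000'  # placeholder session UUID

# plain loader: spikes and clusters only, acronyms labelled via brain_regions
spikes, clusters = load_spike_sorting(eid, one=one, brain_regions=br)

# fast loader: also returns channels and merges channel info into clusters;
# nested=False unwraps the single-probe dictionaries
spikes, clusters, channels = load_spike_sorting_fast(
    eid, one=one, probe='probe00', brain_regions=br, nested=False)
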

ibllib/dsp/cadzow.py

Lines changed: 17 additions & 6 deletions
@@ -48,14 +48,25 @@ def trajectory(x, y):


 def denoise(WAV, x, y, r, imax=None, niter=1):
-    WAV_ = np.zeros_like(WAV)
+    """
+    Applies Cadzow denoising by de-ranking spatial matrices in the frequency domain
+    :param WAV: np array nc / ns in frequency domain
+    :param x:
+    :param y:
+    :param r:
+    :param imax:
+    :param niter:
+    :return:
+    """
+    WAV_ = np.copy(WAV)
     imax = np.minimum(WAV.shape[-1], imax) if imax else WAV.shape[-1]
     T, it, itr, trcount = trajectory(x, y)
     for ind_f in np.arange(imax):
-        T[it] = WAV[itr, ind_f]
-        T_ = derank(T, r)
-        WAV_[:, ind_f] = np.bincount(itr, weights=np.real(T_[it]))
-        WAV_[:, ind_f] += 1j * np.bincount(itr, weights=np.imag(T_[it]))
-        WAV_[:, ind_f] /= trcount
+        for _ in np.arange(niter):
+            T[it] = WAV_[itr, ind_f]
+            T_ = derank(T, r)
+            WAV_[:, ind_f] = np.bincount(itr, weights=np.real(T_[it]))
+            WAV_[:, ind_f] += 1j * np.bincount(itr, weights=np.imag(T_[it]))
+            WAV_[:, ind_f] /= trcount

     return WAV_
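
Note: the new inner loop de-ranks the trajectory matrix niter times per frequency bin. derank itself is not shown in this diff; a generic rank-r truncation via SVD, as a sketch of what such a step typically does (an illustrative stand-in, not ibllib's implementation):

import numpy as np

def derank_svd(T, r):
    # keep the r largest singular values of T and zero the rest;
    # ibllib.dsp.cadzow.derank may differ in detail
    U, S, Vh = np.linalg.svd(T, full_matrices=False)
    S[r:] = 0
    return (U * S) @ Vh
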

ibllib/dsp/voltage.py

Lines changed: 3 additions & 1 deletion
@@ -40,11 +40,13 @@ def reject_channels(x, fs, butt_kwargs=None, threshold=0.6, trx=1):
 def agc(x, wl=.5, si=.002, epsilon=1e-8):
     """
     Automatic gain control
+    w_agc, gain = agc(w, wl=.5, si=.002, epsilon=1e-8)
+    such that w_agc / gain = w
     :param x: seismic array (sample last dimension)
     :param wl: window length (secs)
     :param si: sampling interval (secs)
     :param epsilon: whitening (useful mainly for synthetic data)
-    :return:
+    :return: AGC data array, gain applied to data
     """
     ns_win = np.round(wl / si / 2) * 2 + 1
     w = np.hanning(ns_win)
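
Note: per the amended docstring, the returned gain lets the caller undo the normalisation. A quick sketch on synthetic data, assuming the ibllib.dsp.voltage import path:

import numpy as np
from ibllib.dsp.voltage import agc

si = .002  # 500 Hz sampling interval
w = np.random.randn(2, 500) * np.linspace(.1, 5, 500)  # amplitude ramps up 50x
w_agc, gain = agc(w, wl=.5, si=si)
# per the docstring, dividing the AGC output by the gain recovers the input
np.testing.assert_allclose(w_agc / gain, w, rtol=1e-6)
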

ibllib/ephys/ephysqc.py

Lines changed: 39 additions & 13 deletions
@@ -14,7 +14,7 @@

 from brainbox.metrics.single_units import spike_sorting_metrics
 from brainbox.io.spikeglx import stream as sglx_streamer
-from ibllib.ephys import sync_probes
+from ibllib.ephys import sync_probes, neuropixel, spikes
 from ibllib.io import spikeglx
 import ibllib.dsp as dsp
 from ibllib.qc import base
@@ -31,6 +31,7 @@
 BATCHES_SPACING = 300
 TMIN = 40
 SAMPLE_LENGTH = 1
+SPIKE_THRESHOLD_UV = -50  # negative, the threshold used for spike detection on pre-processed raw data


 class EphysQC(base.QC):
@@ -89,6 +90,23 @@ def load_data(self) -> None:
             bin_file = next(meta_file.parent.glob(f'*{dstype}.*bin'), None)
             self.data[f'{dstype}'] = spikeglx.Reader(bin_file, open=True) if bin_file is not None else None

+    @staticmethod
+    def _compute_metrics_array(raw, fs, h):
+        """
+        From a numpy array, computes rms on raw data, destripes, computes rms on destriped data
+        and performs a simple spike detection
+        :param raw: voltage numpy.array(ntraces, nsamples)
+        :param fs: sampling frequency (Hz)
+        :param h: dictionary containing sensor coordinates, see ibllib.ephys.neuropixel.trace_header
+        :return: 3 numpy vectors of nchannels length
+        """
+        destripe = dsp.destripe(raw, fs=fs, neuropixel_version=1)
+        rms_raw = dsp.rms(raw)
+        rms_pre_proc = dsp.rms(destripe)
+        detections = spikes.detection(data=destripe.T, fs=fs, h=h, detect_threshold=SPIKE_THRESHOLD_UV * 1e-6)
+        spike_rate = np.bincount(detections.trace, minlength=raw.shape[0]).astype(np.float32)
+        return rms_raw, rms_pre_proc, spike_rate
+
     def run(self, update: bool = False, overwrite: bool = True, stream: bool = None, **kwargs) -> (str, dict):
         """
         Run QC on samples of the .ap file, and on the entire file for .lf data if it is present.
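
Note: the staticmethod can be exercised without instantiating the QC object. A minimal sketch on synthetic data, assuming a Neuropixel 1.0 header; the 1e-5 scaling mimics voltage-scaled data:

import numpy as np
from ibllib.ephys import neuropixel
from ibllib.ephys.ephysqc import EphysQC

fs = 30000                               # AP band sampling rate (Hz)
h = neuropixel.trace_header(version=1)
raw = np.random.randn(384, fs) * 1e-5    # 1 s of synthetic voltage data
rms_raw, rms_pre_proc, spike_rate = EphysQC._compute_metrics_array(raw, fs, h)
assert rms_raw.shape == rms_pre_proc.shape == spike_rate.shape == (384,)
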
@@ -109,39 +127,47 @@ def run(self, update: bool = False, overwrite: bool = True, stream: bool = None,
         # TODO: This should go in a separate function once we have a spikeglx.Streamer that behaves like the Reader
         if self.data.ap_meta:
             rms_file = self.probe_path.joinpath("_iblqc_ephysChannels.apRMS.npy")
-            if rms_file.exists() and not overwrite:
+            spike_rate_file = self.probe_path.joinpath("_iblqc_ephysChannels.rawSpikeRates.npy")
+            if all([rms_file.exists(), spike_rate_file.exists()]) and not overwrite:
                 _logger.warning(f'RMS map already exists for .ap data in {self.probe_path}, skipping. '
                                 f'Use overwrite option.')
                 median_rms = np.load(rms_file)
             else:
                 rl = self.data.ap_meta.fileTimeSecs
-                nc = spikeglx._get_nchannels_from_meta(self.data.ap_meta)
+                nsync = len(spikeglx._get_sync_trace_indices_from_meta(self.data.ap_meta))
+                nc = spikeglx._get_nchannels_from_meta(self.data.ap_meta) - nsync
+                neuropixel_version = spikeglx._get_neuropixel_major_version_from_meta(self.data.ap_meta)
+                # verify that the channel layout is correct according to IBL layout
+                h = neuropixel.trace_header(neuropixel_version)
+                th = spikeglx._geometry_from_meta(self.data.ap_meta)
+                if not (np.all(h['x'] == th['x']) and np.all(h['y'] == th['y'])):
+                    _logger.critical("Channel geometry seems incorrect")
+                    raise ValueError("Wrong Neuropixel channel mapping used - ABORT")
                 t0s = np.arange(TMIN, rl - SAMPLE_LENGTH, BATCHES_SPACING)
-                all_rms = np.zeros((2, nc - 1, t0s.shape[0]))
+                all_rms = np.zeros((2, nc, t0s.shape[0]))
+                all_srs = np.zeros((nc, t0s.shape[0]))
                 # If the ap.bin file is not present locally, stream it
                 if self.data.ap is None and self.stream is True:
                     _logger.warning(f'Streaming .ap data to compute RMS samples for probe {self.pid}')
                     for i, t0 in enumerate(tqdm(t0s)):
                         sr, _ = sglx_streamer(self.pid, t0=t0, nsecs=1, one=self.one, remove_cached=True)
-                        raw = sr[:, :-1].T
-                        destripe = dsp.destripe(raw, fs=sr.fs, neuropixel_version=1)
-                        all_rms[0, :, i] = dsp.rms(raw)
-                        all_rms[1, :, i] = dsp.rms(destripe)
+                        raw = sr[:, :-nsync].T
+                        all_rms[0, :, i], all_rms[1, :, i], all_srs[:, i] = self._compute_metrics_array(raw, sr.fs, h)
                 elif self.data.ap is None and self.stream is not True:
                     _logger.warning('Raw .ap data is not available locally. Run with stream=True in order to stream '
                                     'data for calculating RMS samples.')
                 else:
                     _logger.info(f'Computing RMS samples for .ap data using local data in {self.probe_path}')
                     for i, t0 in enumerate(t0s):
                         sl = slice(int(t0 * self.data.ap.fs), int((t0 + SAMPLE_LENGTH) * self.data.ap.fs))
-                        raw = self.data.ap[sl, :-1].T
-                        destripe = dsp.destripe(raw, fs=self.data.ap.fs, neuropixel_version=1)
-                        all_rms[0, :, i] = dsp.rms(raw)
-                        all_rms[1, :, i] = dsp.rms(destripe)
+                        raw = self.data.ap[sl, :-nsync].T
+                        all_rms[0, :, i], all_rms[1, :, i], all_srs[:, i] = self._compute_metrics_array(raw, self.data.ap.fs, h)
                 # Calculate the median RMS across all samples per channel
                 median_rms = np.median(all_rms, axis=-1)
+                median_spike_rate = np.median(all_srs, axis=-1)
                 np.save(rms_file, median_rms)
-                qc_files.append(rms_file)
+                np.save(spike_rate_file, median_spike_rate)
+                qc_files.extend([rms_file, spike_rate_file])

             for p in [10, 90]:
                 self.metrics[f'apRms_p{p}_raw'] = np.format_float_scientific(np.percentile(median_rms[0, :], p),
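
Note: the hunk cuts off mid-call, so the rest of the metrics computation is not shown. A sketch of what the percentile summary plausibly looks like; the precision argument and the synthetic median_rms values are assumptions:

import numpy as np

median_rms = np.abs(np.random.randn(2, 384)) * 1e-5  # row 0: raw, row 1: destriped
metrics = {}
for p in [10, 90]:
    # summarise the per-channel median RMS by its 10th and 90th percentiles
    metrics[f'apRms_p{p}_raw'] = np.format_float_scientific(
        np.percentile(median_rms[0, :], p), precision=2)
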
