Skip to content

Commit 5efc723

Browse files
authored
Merge pull request #428 from int-brain-lab/olivier
Olivier
2 parents d61ee27 + 169a809 commit 5efc723

File tree

11 files changed

+305
-145
lines changed

11 files changed

+305
-145
lines changed

brainbox/io/one.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,6 @@ def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, align
276276
collections = one.list_collections(eid, filename='channels*', collection=collection,
277277
revision=revision)
278278
probe_collection = _get_spike_sorting_collection(collections, probe)
279-
print(probe_collection)
280279
chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection)
281280
depths = chn_coords[:, 1]
282281

@@ -388,7 +387,7 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas
388387

389388

390389
def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
391-
brain_regions=None, nested=True):
390+
brain_regions=None, nested=True, collection=None, return_collection=False):
392391
"""
393392
From an eid, loads spikes and clusters for all probes
394393
The following set of dataset types are loaded:
@@ -403,12 +402,15 @@ def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike
403402
:param probe: name of probe to load in, if not given all probes for session will be loaded
404403
:param dataset_types: additional spikes/clusters objects to add to the standard default list
405404
:param spike_sorter: name of the spike sorting you want to load (None for default)
405+
:param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00"
406406
:param return_channels: (bool) defaults to False otherwise tries and load channels from disk
407407
:param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
408408
:param nested: if a single probe is required, do not output a dictionary with the probe name as key
409-
:return: spikes, clusters (dict of bunch, 1 bunch per probe)
409+
:param return_collection: (False) if True, will return the collection used to load
410+
:return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
410411
"""
411-
collection = _collection_filter_from_args(probe, spike_sorter)
412+
if collection is None:
413+
collection = _collection_filter_from_args(probe, spike_sorter)
412414
_logger.debug(f"load spike sorting with collection filter {collection}")
413415
kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types,
414416
brain_regions=brain_regions)
@@ -419,7 +421,10 @@ def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike
419421
channels = channels[k]
420422
clusters = clusters[k]
421423
spikes = spikes[k]
422-
return spikes, clusters, channels
424+
if return_collection:
425+
return spikes, clusters, channels, collection
426+
else:
427+
return spikes, clusters, channels
423428

424429

425430
def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,

examples/one/histology/register_lasagna_tracks_alyx.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"""
2525
# Author: Olivier, Gaelle
2626
from pathlib import Path
27-
27+
from ibllib.atlas import AllenAtlas
2828
from one.api import ONE
2929

3030
from ibllib.pipes import histology
@@ -42,11 +42,12 @@
4242

4343
# ======== DO NOT EDIT BELOW ====
4444
one = ONE(base_url=ALYX_URL)
45+
ba = AllenAtlas()
4546

4647
if EXAMPLE_OVERWRITE:
4748
# TODO Olivier : Function to download examples folder
4849
cachepath = Path(one.alyx.cache_dir)
4950
path_tracks = cachepath.joinpath('examples', 'histology', 'tracks_to_add')
5051

51-
histology.register_track_files(path_tracks=path_tracks, one=one, overwrite=True)
52-
histology.detect_missing_histology_tracks(path_tracks=path_tracks, one=one)
52+
histology.register_track_files(path_tracks=path_tracks, one=one, overwrite=True, brain_atlas=ba)
53+
histology.detect_missing_histology_tracks(path_tracks=path_tracks, one=one, brain_atlas=ba)

ibllib/atlas/atlas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import matplotlib.pyplot as plt
44
from pathlib import Path, PurePosixPath
5-
5+
from functools import lru_cache
66
import numpy as np
77
import nrrd
88

@@ -710,6 +710,7 @@ def get_brain_entry(traj, brain_atlas):
710710
return Insertion._get_surface_intersection(traj, brain_atlas, surface='top')
711711

712712

713+
@lru_cache(maxsize=1)
713714
class AllenAtlas(BrainAtlas):
714715
"""
715716
Instantiates an atlas.BrainAtlas corresponding to the Allen CCF at the given resolution

ibllib/dsp/voltage.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def fk(x, si=.002, dx=1, vbounds=None, btype='highpass', ntr_pad=0, ntr_tap=None
109109
return xf / gain
110110

111111

112-
def car(x, collection=None, lagc=300, butter_kwargs=None):
112+
def car(x, collection=None, lagc=300, butter_kwargs=None, **kwargs):
113113
"""
114114
Applies common average referencing with optional automatic gain control
115115
:param x: the input array to be filtered. dimension, the filtering is considering
@@ -223,7 +223,22 @@ def interpolate_bad_channels(data, channel_labels=None, h=None, p=1.3, kriging_d
223223
return data
224224

225225

226-
def destripe(x, fs, neuropixel_version=1, butter_kwargs=None, k_kwargs=None, channel_labels=None):
226+
def _get_destripe_parameters(fs, butter_kwargs, k_kwargs, k_filter):
227+
"""gets the default params for destripe. This is used for both the destripe fcn on a
228+
numpy array and the function that actuates on a cbin file"""
229+
if butter_kwargs is None:
230+
butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
231+
if k_kwargs is None:
232+
k_kwargs = {'ntr_pad': 60, 'ntr_tap': 0, 'lagc': 3000,
233+
'butter_kwargs': {'N': 3, 'Wn': 0.01, 'btype': 'highpass'}}
234+
if k_filter:
235+
spatial_fcn = lambda dat: kfilt(dat, **k_kwargs) # noqa
236+
else:
237+
spatial_fcn = lambda dat: car(dat, **k_kwargs) # noqa
238+
return butter_kwargs, k_kwargs, spatial_fcn
239+
240+
241+
def destripe(x, fs, neuropixel_version=1, butter_kwargs=None, k_kwargs=None, channel_labels=None, k_filter=True):
227242
"""Super Car (super slow also...) - far from being set in stone but a good workflow example
228243
:param x: demultiplexed array (nc, ns)
229244
:param fs: sampling frequency
@@ -241,18 +256,10 @@ def destripe(x, fs, neuropixel_version=1, butter_kwargs=None, k_kwargs=None, cha
241256
:param butter_kwargs: (optional, None) butterworth params, see the code for the defaults dict
242257
:param k_kwargs: (optional, None) K-filter params, see the code for the defaults dict
243258
can also be set to 'car', in which case the median accross channels will be subtracted
259+
:param k_filter (True): applies k-filter by default, otherwise, apply CAR.
244260
:return: x, filtered array
245261
"""
246-
if butter_kwargs is None:
247-
butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
248-
if k_kwargs is None:
249-
k_kwargs = {'ntr_pad': 60, 'ntr_tap': 0, 'lagc': 3000,
250-
'butter_kwargs': {'N': 3, 'Wn': 0.01, 'btype': 'highpass'}}
251-
spatial_fcn = lambda dat: kfilt(dat, **k_kwargs) # noqa
252-
elif isinstance(k_kwargs, dict):
253-
spatial_fcn = lambda dat: kfilt(dat, **k_kwargs) # noqa
254-
else:
255-
spatial_fcn = lambda dat: car(dat, lagc=int(0.1 * fs)) # noqa
262+
butter_kwargs, k_kwargs, spatial_fcn = _get_destripe_parameters(fs, butter_kwargs, k_kwargs, k_filter)
256263
h = neuropixel.trace_header(version=neuropixel_version)
257264
if channel_labels is True:
258265
channel_labels, _ = detect_bad_channels(x, fs)
@@ -262,8 +269,7 @@ def destripe(x, fs, neuropixel_version=1, butter_kwargs=None, k_kwargs=None, cha
262269
# channel interpolation
263270
# apply ADC shift
264271
if neuropixel_version is not None:
265-
sample_shift = h['sample_shift'] if (30000 / fs) < 10 else h['sample_shift'] * fs / 30000
266-
x = fshift(x, sample_shift, axis=1)
272+
x = fshift(x, h['sample_shift'], axis=1)
267273
# apply spatial filter only on channels that are inside of the brain
268274
if channel_labels is not None:
269275
x = interpolate_bad_channels(x, channel_labels, h)
@@ -275,7 +281,8 @@ def destripe(x, fs, neuropixel_version=1, butter_kwargs=None, k_kwargs=None, cha
275281

276282

277283
def decompress_destripe_cbin(sr_file, output_file=None, h=None, wrot=None, append=False, nc_out=None, butter_kwargs=None,
278-
dtype=np.int16, ns2add=0, nbatch=None, nprocesses=None, compute_rms=True, reject_channels=True):
284+
dtype=np.int16, ns2add=0, nbatch=None, nprocesses=None, compute_rms=True, reject_channels=True,
285+
k_kwargs=None, k_filter=True):
279286
"""
280287
From a spikeglx Reader object, decompresses and apply ADC.
281288
Saves output as a flat binary file in int16
@@ -295,6 +302,8 @@ def decompress_destripe_cbin(sr_file, output_file=None, h=None, wrot=None, appen
295302
interp 3:outside of brain and discard
296303
:param reject_channels: (True) detects noisy or bad channels and interpolate them. Channels outside of the brain are left
297304
untouched
305+
:param k_kwargs: (None) arguments for the kfilter function
306+
:param k_filter: (True) Performs a k-filter - if False will do median common average referencing
298307
:return:
299308
"""
300309
import pyfftw
@@ -306,9 +315,7 @@ def decompress_destripe_cbin(sr_file, output_file=None, h=None, wrot=None, appen
306315
if reject_channels: # get bad channels if option is on
307316
channel_labels = detect_bad_channels_cbin(sr)
308317
assert isinstance(sr_file, str) or isinstance(sr_file, Path)
309-
butter_kwargs = butter_kwargs or {'N': 3, 'Wn': 300 / sr.fs * 2, 'btype': 'highpass'}
310-
k_kwargs = {'ntr_pad': 60, 'ntr_tap': 0, 'lagc': 3000,
311-
'butter_kwargs': {'N': 3, 'Wn': 0.01, 'btype': 'highpass'}}
318+
butter_kwargs, k_kwargs, spatial_fcn = _get_destripe_parameters(sr.fs, butter_kwargs, k_kwargs, k_filter)
312319
h = neuropixel.trace_header(version=1) if h is None else h
313320
ncv = h['sample_shift'].size # number of channels
314321
output_file = sr.file_bin.with_suffix('.bin') if output_file is None else output_file
@@ -412,9 +419,9 @@ def my_function(i_chunk, n_chunk):
412419
if reject_channels:
413420
chunk = interpolate_bad_channels(chunk, channel_labels, h=h)
414421
inside_brain = np.where(channel_labels != 3)[0]
415-
chunk[inside_brain, :] = kfilt(chunk[inside_brain, :], **k_kwargs) # apply the k-filter
422+
chunk[inside_brain, :] = spatial_fcn(chunk[inside_brain, :]) # apply the k-filter / CAR
416423
else:
417-
chunk = kfilt(chunk, **k_kwargs) # apply the k-filter
424+
chunk = spatial_fcn(chunk) # apply the k-filter / CAR
418425
# add back sync trace and save
419426
chunk = np.r_[chunk, _sr[first_s:last_s, ncv:].T].T
420427
intnorm = 1 / _sr.channel_conversion_sample2v['ap'] if dtype == np.int16 else 1.
@@ -476,7 +483,7 @@ def normalize(z):
476483
return rcor
477484

478485

479-
def detect_bad_channels(raw, fs, similarity_threshold=(-0.5, 1), psd_hf_threshold=0.02):
486+
def detect_bad_channels(raw, fs, similarity_threshold=(-0.5, 1), psd_hf_threshold=None):
480487
"""
481488
Bad channels detection for Neuropixel probes
482489
Labels channels
@@ -540,7 +547,9 @@ def nxcor(x, ref):
540547
raw = raw - np.mean(raw, axis=-1)[:, np.newaxis] # removes DC offset
541548
xcor = channels_similarity(raw)
542549
fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs) # units; uV ** 2 / Hz
543-
550+
if psd_hf_threshold is None:
551+
# the LFP band data is obviously much stronger so auto-adjust the default threshold
552+
psd_hf_threshold = 1.4 if fs < 5000 else 0.02
544553
sos_hp = scipy.signal.butter(**{'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}, output='sos')
545554
hf = scipy.signal.sosfiltfilt(sos_hp, raw)
546555
xcorf = channels_similarity(hf)

ibllib/pipes/ephys_preprocessing.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ibllib.pipes import tasks
2121
from ibllib.pipes.training_preprocessing import TrainingRegisterRaw as EphysRegisterRaw
2222
from ibllib.pipes.misc import create_alyx_probe_insertions
23+
from ibllib.qc.alignment_qc import get_aligned_channels
2324
from ibllib.qc.task_extractors import TaskQCExtractor
2425
from ibllib.qc.task_metrics import TaskQC
2526
from ibllib.qc.camera import run_all_qc as run_camera_qc
@@ -406,6 +407,18 @@ def _run(self, probes=None):
406407
tar_dir.mkdir(parents=True, exist_ok=True)
407408
out = spikes.ks2_to_tar(ks2_dir, tar_dir, force=self.FORCE_RERUN)
408409
out_files.extend(out)
410+
411+
if self.one:
412+
eid = self.one.path2eid(self.session_path, query_type='remote')
413+
ins = self.one.alyx.rest('insertions', 'list', session=eid, name=label)
414+
if len(ins) != 0:
415+
resolved = ins[0].get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \
416+
get('alignment_resolved', False)
417+
if resolved:
418+
chns = np.load(probe_out_path.joinpath('channels.localCoordinates.npy'))
419+
out = get_aligned_channels(ins[0], chns, one=self.one, save_dir=probe_out_path)
420+
out_files.extend(out)
421+
409422
except BaseException:
410423
_logger.error(traceback.format_exc())
411424
self.status = -1

0 commit comments

Comments
 (0)