Skip to content

Commit a61aa50

Browse files
authored
Merge pull request #436 from int-brain-lab/release/2.8.0
Release/2.8.0
2 parents 482ff95 + a6b956e commit a61aa50

File tree

6 files changed

+185
-43
lines changed

6 files changed

+185
-43
lines changed

brainbox/io/one.py

Lines changed: 170 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,33 @@
11
"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment."""
2+
from dataclasses import dataclass
23
import logging
34
import os
5+
from pathlib import Path
46

57
import numpy as np
68
import pandas as pd
79
from scipy.interpolate import interp1d
810

911
from one.api import ONE
12+
import one.alf.io as alfio
1013

1114
from iblutil.util import Bunch
1215

1316
from ibllib.io import spikeglx
1417
from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
1518
from ibllib.ephys.neuropixel import SITES_COORDINATES, TIP_SIZE_UM, trace_header
16-
from ibllib.atlas import atlas
17-
from ibllib.atlas import AllenAtlas
19+
from ibllib.atlas import atlas, AllenAtlas
1820
from ibllib.pipes import histology
1921
from ibllib.pipes.ephys_alignment import EphysAlignment
2022

2123
from brainbox.core import TimeSeries
2224
from brainbox.processing import sync
25+
from brainbox.metrics.single_units import quick_unit_metrics
2326

2427
_logger = logging.getLogger('ibllib')
2528

2629

27-
SPIKES_ATTRIBUTES = ['clusters', 'times']
30+
SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
2831
CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics']
2932

3033

@@ -120,7 +123,7 @@ def _channels_alf2bunch(channels, brain_regions=None):
120123

121124

122125
def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None,
123-
brain_regions=None):
126+
brain_regions=None, return_collection=False):
124127
"""
125128
Generic function to load spike sorting according data using ONE.
126129
@@ -143,7 +146,7 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch
143146
A particular revision return (defaults to latest revision). See `ALF documentation`_ for
144147
details.
145148
return_channels : bool
146-
Defaults to False otherwise loads channels from disk (takes longer)
149+
Defaults to False otherwise loads channels from disk
147150
148151
.. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components
149152
@@ -208,9 +211,9 @@ def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=
208211
# if the spike sorter has not aligned data, try and get the alignment available
209212
if 'brainLocationIds_ccf_2017' not in channels[probe].keys():
210213
aligned_channel_collections = one.list_collections(
211-
eid, filename='channels.brainLocationIds_ccf_2017*', collection=f'alf/{probe}', revision=revision)
214+
eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision)
212215
if len(aligned_channel_collections) == 0:
213-
_logger.warning(f"no resolved alignment dataset found for {eid}/{probe}")
216+
_logger.debug(f"no resolved alignment dataset found for {eid}/{probe}")
214217
continue
215218
_logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}")
216219
ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe)
@@ -266,8 +269,8 @@ def channel_locations_interpolation(channels_aligned, channels=None, brain_regio
266269

267270

268271
def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False,
269-
brain_atlas=None):
270-
print('from traj')
272+
brain_atlas=None, return_source=False):
273+
_logger.debug(f"trying to load from traj {probe}")
271274
channels = Bunch()
272275
brain_atlas = brain_atlas or AllenAtlas
273276
# need to find the collection bruh
@@ -290,10 +293,9 @@ def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, align
290293
xyz = np.array(insertion['json']['xyz_picks']) / 1e6
291294
if resolved:
292295

293-
_logger.info(f'Channel locations for {eid}/{probe} have been resolved. '
294-
f'Channel and cluster locations obtained from ephys aligned histology '
295-
f'track.')
296-
296+
_logger.debug(f'Channel locations for {eid}/{probe} have been resolved. '
297+
f'Channel and cluster locations obtained from ephys aligned histology '
298+
f'track.')
297299
traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
298300
provenance='Ephys aligned histology track')[0]
299301
align_key = insertion['json']['extended_qc']['alignment_stored']
@@ -304,12 +306,12 @@ def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, align
304306
brain_atlas=brain_atlas, speedy=True)
305307
chans = ephysalign.get_channel_locations(feature, track)
306308
channels[probe] = _channels_traj2bunch(chans, brain_atlas)
307-
309+
source = 'resolved'
308310
elif counts > 0 and aligned:
309-
_logger.info(f'Channel locations for {eid}/{probe} have not been '
310-
f'resolved. However, alignment flag set to True so channel and cluster'
311-
f' locations will be obtained from latest available ephys aligned '
312-
f'histology track.')
311+
_logger.debug(f'Channel locations for {eid}/{probe} have not been '
312+
f'resolved. However, alignment flag set to True so channel and cluster'
313+
f' locations will be obtained from latest available ephys aligned '
314+
f'histology track.')
313315
# get the latest user aligned channels
314316
traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe,
315317
provenance='Ephys aligned histology track')[0]
@@ -322,28 +324,31 @@ def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, align
322324
chans = ephysalign.get_channel_locations(feature, track)
323325

324326
channels[probe] = _channels_traj2bunch(chans, brain_atlas)
325-
327+
source = 'aligned'
326328
else:
327-
_logger.info(f'Channel locations for {eid}/{probe} have not been resolved. '
328-
f'Channel and cluster locations obtained from histology track.')
329+
_logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. '
330+
f'Channel and cluster locations obtained from histology track.')
329331
# get the channels from histology tracing
330332
xyz = xyz[np.argsort(xyz[:, 2]), :]
331333
chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6)
332334

333335
channels[probe] = _channels_traj2bunch(chans, brain_atlas)
334-
336+
source = 'traced'
335337
channels[probe]['axial_um'] = chn_coords[:, 1]
336338
channels[probe]['lateral_um'] = chn_coords[:, 0]
337339

338340
else:
339-
_logger.warning(f'Histology tracing for {probe} does not exist. '
340-
f'No channels for {probe}')
341+
_logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}')
342+
source = ''
341343
channels = None
342344

343-
return channels
345+
if return_source:
346+
return channels, source
347+
else:
348+
return channels
344349

345350

346-
def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None):
351+
def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None, return_source=False):
347352
"""
348353
Load the brain locations of each channel for a given session/probe
349354
@@ -360,12 +365,14 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas
360365
Whether to get the latest user aligned channel when not resolved or use histology track
361366
brain_atlas : ibllib.atlas.BrainAtlas
362367
Brain atlas object (default: Allen atlas)
363-
368+
return_source: bool
369+
if True returns the source of the channel locations (default False)
364370
Returns
365371
-------
366372
dict of one.alf.io.AlfBunch
367373
A dict with probe labels as keys, contains channel locations with keys ('acronym',
368374
'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
375+
optional: string 'resolved', 'aligned', 'traced' or ''
369376
"""
370377
one = one or ONE()
371378
brain_atlas = brain_atlas or AllenAtlas()
@@ -379,8 +386,8 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas
379386
brain_regions=brain_atlas.regions)
380387
incomplete_probes = [k for k in channels if 'x' not in channels[k]]
381388
for iprobe in incomplete_probes:
382-
channels_ = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned,
383-
brain_atlas=brain_atlas)
389+
channels_, source = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned,
390+
brain_atlas=brain_atlas, return_source=True)
384391
if channels_ is not None:
385392
channels[iprobe] = channels_[iprobe]
386393
return channels
@@ -416,7 +423,7 @@ def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike
416423
brain_regions=brain_regions)
417424
spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True)
418425
clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
419-
if nested is False:
426+
if nested is False and len(spikes.keys()) == 1:
420427
k = list(spikes.keys())[0]
421428
channels = channels[k]
422429
clusters = clusters[k]
@@ -428,7 +435,7 @@ def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike
428435

429436

430437
def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
431-
brain_regions=None):
438+
brain_regions=None, return_collection=False):
432439
"""
433440
From an eid, loads spikes and clusters for all probes
434441
The following set of dataset types are loaded:
@@ -445,18 +452,22 @@ def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sort
445452
:param spike_sorter: name of the spike sorting you want to load (None for default)
446453
:param return_channels: (bool) defaults to False otherwise tries and load channels from disk
447454
:param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
455+
:param return_collection:(bool - False) if True, returns the collection for loading the data
448456
:return: spikes, clusters (dict of bunch, 1 bunch per probe)
449457
"""
450458
collection = _collection_filter_from_args(probe, spike_sorter)
451459
_logger.debug(f"load spike sorting with collection filter {collection}")
452460
spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision,
453461
return_channels=False, dataset_types=dataset_types,
454462
brain_regions=brain_regions)
455-
return spikes, clusters
463+
if return_collection:
464+
return spikes, clusters, collection
465+
else:
466+
return spikes, clusters
456467

457468

458469
def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None,
459-
spike_sorter=None, brain_atlas=None):
470+
spike_sorter=None, brain_atlas=None, nested=True, return_collection=False):
460471
"""
461472
For a given eid, get spikes, clusters and channels information, and merges clusters
462473
and channels information before returning all three variables.
@@ -479,6 +490,8 @@ def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, da
479490
available otherwise the default MATLAB kilosort)
480491
brain_atlas : ibllib.atlas.BrainAtlas
481492
Brain atlas object (default: Allen atlas)
493+
return_collection: bool
494+
Returns an extra argument with the collection chosen
482495
483496
Returns
484497
-------
@@ -495,13 +508,21 @@ def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, da
495508
# --- Get spikes and clusters data
496509
one = one or ONE()
497510
brain_atlas = brain_atlas or AllenAtlas()
498-
spikes, clusters = load_spike_sorting(eid, one=one, probe=probe, dataset_types=dataset_types,
499-
spike_sorter=spike_sorter)
511+
spikes, clusters, collection = load_spike_sorting(
512+
eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True)
500513
# -- Get brain regions and assign to clusters
501514
channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned,
502515
brain_atlas=brain_atlas)
503516
clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
504-
return spikes, clusters, channels
517+
if nested is False and len(spikes.keys()) == 1:
518+
k = list(spikes.keys())[0]
519+
channels = channels[k]
520+
clusters = clusters[k]
521+
spikes = spikes[k]
522+
if return_collection:
523+
return spikes, clusters, channels, collection
524+
else:
525+
return spikes, clusters, channels
505526

506527

507528
def load_ephys_session(eid, one=None):
@@ -837,3 +858,116 @@ def load_channels_from_insertion(ins, depths=None, one=None, ba=None):
837858
brain_atlas=ba, speedy=True)
838859
xyz_channels = ephysalign.get_channel_locations(feature, track)
839860
return xyz_channels
861+
862+
863+
@dataclass
864+
class SpikeSortingLoader:
865+
"""Class for loading spike sorting"""
866+
pid: str
867+
one: ONE
868+
atlas: None
869+
# the following properties are the outcome of the post init function
870+
eid: str = ''
871+
session_path: Path = ''
872+
collections: list = None
873+
datasets: list = None # list of all datasets belonging to the session
874+
# the following properties are the outcome of a reading function
875+
files: dict = None
876+
collection: str = ''
877+
histology: str = '' # 'alf', 'resolved', 'aligned' or 'traced'
878+
spike_sorting_path: Path = None
879+
880+
def __post_init__(self):
881+
self.eid, self.pname = self.one.pid2eid(self.pid)
882+
self.session_path = self.one.eid2path(self.eid)
883+
self.collections = self.one.list_collections(
884+
self.eid, filename='spikes*', collection=f"alf/{self.pname}*")
885+
self.datasets = self.one.list_datasets(self.eid)
886+
self.files = {}
887+
888+
@staticmethod
889+
def _get_attributes(dataset_types):
890+
"""returns attributes to load for spikes and clusters objects"""
891+
if dataset_types is None:
892+
return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES
893+
else:
894+
spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
895+
cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
896+
spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
897+
cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
898+
return spike_attributes, cluster_attributes
899+
900+
def _get_spike_sorting_collection(self, spike_sorter='pykilosort', revision=None):
901+
"""
902+
Filters a list or array of collections to get the relevant spike sorting dataset
903+
if there is a pykilosort, load it
904+
"""
905+
collection = next(filter(lambda c: c == f'alf/{self.pname}/{spike_sorter}', self.collections), None)
906+
# otherwise, prefers the shortest
907+
collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None)
908+
_logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}")
909+
return collection
910+
911+
def _download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None):
912+
if len(self.collections) == 0:
913+
return {}, {}, {}
914+
self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
915+
spike_attributes, cluster_attributes = self._get_attributes(dataset_types)
916+
attributes = {'spikes': spike_attributes, 'clusters': cluster_attributes, 'channels': None}
917+
self.files[obj] = self.one.load_object(self.eid, obj=obj, attribute=attributes[obj],
918+
collection=self.collection, download_only=True)
919+
920+
def download_spike_sorting(self, **kwargs):
921+
"""spike_sorter='pykilosort', dataset_types=None"""
922+
for obj in ['spikes', 'clusters', 'channels']:
923+
self._download_spike_sorting_object(obj=obj, **kwargs)
924+
self.spike_sorting_path = self.files['spikes'][0].parent
925+
926+
def load_spike_sorting(self, **kwargs):
927+
"""spike_sorter='pykilosort', dataset_types=None"""
928+
if len(self.collections) == 0:
929+
return {}, {}, {}
930+
self.download_spike_sorting(**kwargs)
931+
channels = alfio.load_object(self.files['channels'], wildcards=self.one.wildcards)
932+
clusters = alfio.load_object(self.files['clusters'], wildcards=self.one.wildcards)
933+
spikes = alfio.load_object(self.files['spikes'], wildcards=self.one.wildcards)
934+
if 'brainLocationIds_ccf_2017' not in channels:
935+
channels, self.histology = _load_channel_locations_traj(
936+
self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True)
937+
channels = channels[self.pname]
938+
else:
939+
channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions)
940+
self.histology = 'alf'
941+
return spikes, clusters, channels
942+
943+
@staticmethod
944+
def merge_clusters(spikes, clusters, channels, cache_dir=None):
945+
"""merge metrics and channels info - optionally saves a clusters.pqt dataframe"""
946+
if spikes == {}:
947+
return
948+
nc = clusters['channels'].size
949+
# recompute metrics if they are not available
950+
metrics = None
951+
if 'metrics' in clusters:
952+
metrics = clusters.pop('metrics')
953+
if metrics.shape[0] != nc:
954+
metrics = None
955+
if metrics is None:
956+
_logger.debug("recompute clusters metrics")
957+
metrics = pd.DataFrame(quick_unit_metrics(
958+
spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc)))
959+
if isinstance(cache_dir, Path):
960+
metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt'))
961+
for k in metrics.keys():
962+
clusters[k] = metrics[k].to_numpy()
963+
964+
for k in channels.keys():
965+
clusters[k] = channels[k][clusters['channels']]
966+
if cache_dir:
967+
pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt'))
968+
return clusters
969+
970+
@property
971+
def url(self):
972+
"""Gets flatiron URL for the session"""
973+
return str(self.session_path).replace(str(self.one.alyx.cache_dir), 'https://ibl.flatironinstitute.org')

ibllib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.7.1"
1+
__version__ = "2.8.0"
22
import warnings
33

44
from ibllib.misc import logger_config

ibllib/dsp/voltage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def decompress_destripe_cbin(sr_file, output_file=None, h=None, wrot=None, appen
316316
channel_labels = detect_bad_channels_cbin(sr)
317317
assert isinstance(sr_file, str) or isinstance(sr_file, Path)
318318
butter_kwargs, k_kwargs, spatial_fcn = _get_destripe_parameters(sr.fs, butter_kwargs, k_kwargs, k_filter)
319-
h = neuropixel.trace_header(version=1) if h is None else h
319+
h = sr.geometry if h is None else h
320320
ncv = h['sample_shift'].size # number of channels
321321
output_file = sr.file_bin.with_suffix('.bin') if output_file is None else output_file
322322
assert output_file != sr.file_bin

0 commit comments

Comments
 (0)