11"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment."""
2+ from dataclasses import dataclass
23import logging
34import os
5+ from pathlib import Path
46
57import numpy as np
68import pandas as pd
79from scipy .interpolate import interp1d
810
911from one .api import ONE
12+ import one .alf .io as alfio
1013
1114from iblutil .util import Bunch
1215
1316from ibllib .io import spikeglx
1417from ibllib .io .extractors .training_wheel import extract_wheel_moves , extract_first_movement_times
1518from ibllib .ephys .neuropixel import SITES_COORDINATES , TIP_SIZE_UM , trace_header
16- from ibllib .atlas import atlas
17- from ibllib .atlas import AllenAtlas
19+ from ibllib .atlas import atlas , AllenAtlas
1820from ibllib .pipes import histology
1921from ibllib .pipes .ephys_alignment import EphysAlignment
2022
2123from brainbox .core import TimeSeries
2224from brainbox .processing import sync
25+ from brainbox .metrics .single_units import quick_unit_metrics
2326
2427_logger = logging .getLogger ('ibllib' )
2528
2629
27- SPIKES_ATTRIBUTES = ['clusters' , 'times' ]
30+ SPIKES_ATTRIBUTES = ['clusters' , 'times' , 'amps' , 'depths' ]
2831CLUSTERS_ATTRIBUTES = ['channels' , 'depths' , 'metrics' ]
2932
3033
@@ -120,7 +123,7 @@ def _channels_alf2bunch(channels, brain_regions=None):
120123
121124
122125def _load_spike_sorting (eid , one = None , collection = None , revision = None , return_channels = True , dataset_types = None ,
123- brain_regions = None ):
126+ brain_regions = None , return_collection = False ):
124127 """
125128 Generic function to load spike sorting according data using ONE.
126129
@@ -143,7 +146,7 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch
143146 A particular revision return (defaults to latest revision). See `ALF documentation`_ for
144147 details.
145148 return_channels : bool
146- Defaults to False otherwise loads channels from disk (takes longer)
149+ Defaults to False otherwise loads channels from disk
147150
148151 .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components
149152
@@ -208,9 +211,9 @@ def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=
208211 # if the spike sorter has not aligned data, try and get the alignment available
209212 if 'brainLocationIds_ccf_2017' not in channels [probe ].keys ():
210213 aligned_channel_collections = one .list_collections (
211- eid , filename = 'channels.brainLocationIds_ccf_2017*' , collection = f'alf/ { probe } ' , revision = revision )
214+ eid , filename = 'channels.brainLocationIds_ccf_2017*' , collection = probe_collection , revision = revision )
212215 if len (aligned_channel_collections ) == 0 :
213- _logger .warning (f"no resolved alignment dataset found for { eid } /{ probe } " )
216+ _logger .debug (f"no resolved alignment dataset found for { eid } /{ probe } " )
214217 continue
215218 _logger .debug (f"looking for a resolved alignment dataset in { aligned_channel_collections } " )
216219 ac_collection = _get_spike_sorting_collection (aligned_channel_collections , probe )
@@ -266,8 +269,8 @@ def channel_locations_interpolation(channels_aligned, channels=None, brain_regio
266269
267270
268271def _load_channel_locations_traj (eid , probe = None , one = None , revision = None , aligned = False ,
269- brain_atlas = None ):
270- print ( ' from traj' )
272+ brain_atlas = None , return_source = False ):
273+ _logger . debug ( f"trying to load from traj { probe } " )
271274 channels = Bunch ()
272275 brain_atlas = brain_atlas or AllenAtlas
273276 # need to find the collection bruh
@@ -290,10 +293,9 @@ def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, align
290293 xyz = np .array (insertion ['json' ]['xyz_picks' ]) / 1e6
291294 if resolved :
292295
293- _logger .info (f'Channel locations for { eid } /{ probe } have been resolved. '
294- f'Channel and cluster locations obtained from ephys aligned histology '
295- f'track.' )
296-
296+ _logger .debug (f'Channel locations for { eid } /{ probe } have been resolved. '
297+ f'Channel and cluster locations obtained from ephys aligned histology '
298+ f'track.' )
297299 traj = one .alyx .rest ('trajectories' , 'list' , session = eid , probe = probe ,
298300 provenance = 'Ephys aligned histology track' )[0 ]
299301 align_key = insertion ['json' ]['extended_qc' ]['alignment_stored' ]
@@ -304,12 +306,12 @@ def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, align
304306 brain_atlas = brain_atlas , speedy = True )
305307 chans = ephysalign .get_channel_locations (feature , track )
306308 channels [probe ] = _channels_traj2bunch (chans , brain_atlas )
307-
309+ source = 'resolved'
308310 elif counts > 0 and aligned :
309- _logger .info (f'Channel locations for { eid } /{ probe } have not been '
310- f'resolved. However, alignment flag set to True so channel and cluster'
311- f' locations will be obtained from latest available ephys aligned '
312- f'histology track.' )
311+ _logger .debug (f'Channel locations for { eid } /{ probe } have not been '
312+ f'resolved. However, alignment flag set to True so channel and cluster'
313+ f' locations will be obtained from latest available ephys aligned '
314+ f'histology track.' )
313315 # get the latest user aligned channels
314316 traj = one .alyx .rest ('trajectories' , 'list' , session = eid , probe = probe ,
315317 provenance = 'Ephys aligned histology track' )[0 ]
@@ -322,28 +324,31 @@ def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, align
322324 chans = ephysalign .get_channel_locations (feature , track )
323325
324326 channels [probe ] = _channels_traj2bunch (chans , brain_atlas )
325-
327+ source = 'aligned'
326328 else :
327- _logger .info (f'Channel locations for { eid } /{ probe } have not been resolved. '
328- f'Channel and cluster locations obtained from histology track.' )
329+ _logger .debug (f'Channel locations for { eid } /{ probe } have not been resolved. '
330+ f'Channel and cluster locations obtained from histology track.' )
329331 # get the channels from histology tracing
330332 xyz = xyz [np .argsort (xyz [:, 2 ]), :]
331333 chans = histology .interpolate_along_track (xyz , (depths + TIP_SIZE_UM ) / 1e6 )
332334
333335 channels [probe ] = _channels_traj2bunch (chans , brain_atlas )
334-
336+ source = 'traced'
335337 channels [probe ]['axial_um' ] = chn_coords [:, 1 ]
336338 channels [probe ]['lateral_um' ] = chn_coords [:, 0 ]
337339
338340 else :
339- _logger .warning (f'Histology tracing for { probe } does not exist. '
340- f'No channels for { probe } ' )
341+ _logger .warning (f'Histology tracing for { probe } does not exist. No channels for { probe } ' )
342+ source = ''
341343 channels = None
342344
343- return channels
345+ if return_source :
346+ return channels , source
347+ else :
348+ return channels
344349
345350
346- def load_channel_locations (eid , probe = None , one = None , aligned = False , brain_atlas = None ):
351+ def load_channel_locations (eid , probe = None , one = None , aligned = False , brain_atlas = None , return_source = False ):
347352 """
348353 Load the brain locations of each channel for a given session/probe
349354
@@ -360,12 +365,14 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas
360365 Whether to get the latest user aligned channel when not resolved or use histology track
361366 brain_atlas : ibllib.atlas.BrainAtlas
362367 Brain atlas object (default: Allen atlas)
363-
368+ return_source: bool
369+ if True returns the source of the channel lcoations (default False)
364370 Returns
365371 -------
366372 dict of one.alf.io.AlfBunch
367373 A dict with probe labels as keys, contains channel locations with keys ('acronym',
368374 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
375+ optional: string 'resolved', 'aligned', 'traced' or ''
369376 """
370377 one = one or ONE ()
371378 brain_atlas = brain_atlas or AllenAtlas ()
@@ -379,8 +386,8 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas
379386 brain_regions = brain_atlas .regions )
380387 incomplete_probes = [k for k in channels if 'x' not in channels [k ]]
381388 for iprobe in incomplete_probes :
382- channels_ = _load_channel_locations_traj (eid , probe = iprobe , one = one , aligned = aligned ,
383- brain_atlas = brain_atlas )
389+ channels_ , source = _load_channel_locations_traj (eid , probe = iprobe , one = one , aligned = aligned ,
390+ brain_atlas = brain_atlas , return_source = True )
384391 if channels_ is not None :
385392 channels [iprobe ] = channels_ [iprobe ]
386393 return channels
@@ -416,7 +423,7 @@ def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike
416423 brain_regions = brain_regions )
417424 spikes , clusters , channels = _load_spike_sorting (** kwargs , return_channels = True )
418425 clusters = merge_clusters_channels (clusters , channels , keys_to_add_extra = None )
419- if nested is False :
426+ if nested is False and len ( spikes . keys ()) == 1 :
420427 k = list (spikes .keys ())[0 ]
421428 channels = channels [k ]
422429 clusters = clusters [k ]
@@ -428,7 +435,7 @@ def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike
428435
429436
430437def load_spike_sorting (eid , one = None , probe = None , dataset_types = None , spike_sorter = None , revision = None ,
431- brain_regions = None ):
438+ brain_regions = None , return_collection = False ):
432439 """
433440 From an eid, loads spikes and clusters for all probes
434441 The following set of dataset types are loaded:
@@ -445,18 +452,22 @@ def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sort
445452 :param spike_sorter: name of the spike sorting you want to load (None for default)
446453 :param return_channels: (bool) defaults to False otherwise tries and load channels from disk
447454 :param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
455+ :param return_collection:(bool - False) if True, returns the collection for loading the data
448456 :return: spikes, clusters (dict of bunch, 1 bunch per probe)
449457 """
450458 collection = _collection_filter_from_args (probe , spike_sorter )
451459 _logger .debug (f"load spike sorting with collection filter { collection } " )
452460 spikes , clusters = _load_spike_sorting (eid = eid , one = one , collection = collection , revision = revision ,
453461 return_channels = False , dataset_types = dataset_types ,
454462 brain_regions = brain_regions )
455- return spikes , clusters
463+ if return_collection :
464+ return spikes , clusters , collection
465+ else :
466+ return spikes , clusters
456467
457468
458469def load_spike_sorting_with_channel (eid , one = None , probe = None , aligned = False , dataset_types = None ,
459- spike_sorter = None , brain_atlas = None ):
470+ spike_sorter = None , brain_atlas = None , nested = True , return_collection = False ):
460471 """
461472 For a given eid, get spikes, clusters and channels information, and merges clusters
462473 and channels information before returning all three variables.
@@ -479,6 +490,8 @@ def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, da
479490 available otherwise the default MATLAB kilosort)
480491 brain_atlas : ibllib.atlas.BrainAtlas
481492 Brain atlas object (default: Allen atlas)
493+ return_collection: bool
494+ Returns an extra argument with the collection chosen
482495
483496 Returns
484497 -------
@@ -495,13 +508,21 @@ def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, da
495508 # --- Get spikes and clusters data
496509 one = one or ONE ()
497510 brain_atlas = brain_atlas or AllenAtlas ()
498- spikes , clusters = load_spike_sorting ( eid , one = one , probe = probe , dataset_types = dataset_types ,
499- spike_sorter = spike_sorter )
511+ spikes , clusters , collection = load_spike_sorting (
512+ eid , one = one , probe = probe , dataset_types = dataset_types , spike_sorter = spike_sorter , return_collection = True )
500513 # -- Get brain regions and assign to clusters
501514 channels = load_channel_locations (eid , one = one , probe = probe , aligned = aligned ,
502515 brain_atlas = brain_atlas )
503516 clusters = merge_clusters_channels (clusters , channels , keys_to_add_extra = None )
504- return spikes , clusters , channels
517+ if nested is False and len (spikes .keys ()) == 1 :
518+ k = list (spikes .keys ())[0 ]
519+ channels = channels [k ]
520+ clusters = clusters [k ]
521+ spikes = spikes [k ]
522+ if return_collection :
523+ return spikes , clusters , channels , collection
524+ else :
525+ return spikes , clusters , channels
505526
506527
507528def load_ephys_session (eid , one = None ):
@@ -837,3 +858,116 @@ def load_channels_from_insertion(ins, depths=None, one=None, ba=None):
837858 brain_atlas = ba , speedy = True )
838859 xyz_channels = ephysalign .get_channel_locations (feature , track )
839860 return xyz_channels
861+
862+
@dataclass
class SpikeSortingLoader:
    """
    Helper class to load spike sorting data for a single probe insertion.

    Required fields
    ---------------
    pid : str
        Probe insertion UUID.
    one : one.api.ONE
        An instantiated ONE instance used for all queries and downloads.
    atlas : ibllib.atlas.AllenAtlas
        Brain atlas used to label channel locations.

    The remaining fields are outputs: they are populated by ``__post_init__``
    (session identifiers and dataset listings) or by the loading methods
    (downloaded files, selected collection, histology source).
    """
    pid: str
    one: ONE
    atlas: AllenAtlas
    # the following properties are the outcome of the post init function
    eid: str = ''
    session_path: Path = ''
    collections: list = None
    datasets: list = None  # list of all datasets belonging to the session
    # the following properties are the outcome of a reading function
    files: dict = None
    collection: str = ''
    histology: str = ''  # 'alf', 'resolved', 'aligned' or 'traced'
    spike_sorting_path: Path = None

    def __post_init__(self):
        # resolve the session eid and probe name from the insertion id, then
        # list the candidate spike sorting collections and the session datasets
        self.eid, self.pname = self.one.pid2eid(self.pid)
        self.session_path = self.one.eid2path(self.eid)
        self.collections = self.one.list_collections(
            self.eid, filename='spikes*', collection=f"alf/{self.pname}*")
        self.datasets = self.one.list_datasets(self.eid)
        self.files = {}

    @staticmethod
    def _get_attributes(dataset_types):
        """Returns the spikes and clusters attributes to load.

        The module defaults (SPIKES_ATTRIBUTES / CLUSTERS_ATTRIBUTES) are always
        included; entries of ``dataset_types`` shaped like 'spikes.xxx' or
        'clusters.xxx' add extra attributes on top of them.
        """
        if dataset_types is None:
            return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES
        else:
            spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
            cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
            spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
            cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
            return spike_attributes, cluster_attributes

    def _get_spike_sorting_collection(self, spike_sorter='pykilosort', revision=None):
        """
        Filters a list or array of collections to get the relevant spike sorting dataset.
        If a pykilosort collection exists for this probe it is preferred; otherwise the
        shortest matching collection is returned (None when there is no match).

        NOTE(review): the ``revision`` parameter is accepted but currently unused —
        confirm whether revision filtering should be applied here.
        """
        collection = next(filter(lambda c: c == f'alf/{self.pname}/{spike_sorter}', self.collections), None)
        # otherwise, prefers the shortest
        collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None)
        _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}")
        return collection

    def _download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None):
        """Downloads a single ALF object ('spikes', 'clusters' or 'channels') and
        records the downloaded file paths in ``self.files[obj]``.

        Returns early (without populating ``self.files``) when no spike sorting
        collection exists for this insertion.
        """
        if len(self.collections) == 0:
            return {}, {}, {}
        self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
        spike_attributes, cluster_attributes = self._get_attributes(dataset_types)
        # channels is loaded with all of its attributes (attribute=None)
        attributes = {'spikes': spike_attributes, 'clusters': cluster_attributes, 'channels': None}
        self.files[obj] = self.one.load_object(self.eid, obj=obj, attribute=attributes[obj],
                                               collection=self.collection, download_only=True)

    def download_spike_sorting(self, **kwargs):
        """Downloads spikes, clusters and channels objects for the insertion.

        Keyword arguments are forwarded to ``_download_spike_sorting_object``:
        spike_sorter='pykilosort', dataset_types=None.

        NOTE(review): if no collection exists, ``self.files['spikes']`` is never
        set and the last line raises KeyError — callers should check
        ``self.collections`` first (``load_spike_sorting`` does).
        """
        for obj in ['spikes', 'clusters', 'channels']:
            self._download_spike_sorting_object(obj=obj, **kwargs)
        self.spike_sorting_path = self.files['spikes'][0].parent

    def load_spike_sorting(self, **kwargs):
        """Downloads then loads spikes, clusters and channels for the insertion.

        Keyword arguments are forwarded to the download step:
        spike_sorter='pykilosort', dataset_types=None.

        Returns (spikes, clusters, channels) bunches, or three empty dicts when
        no spike sorting collection exists. Channel locations come from the alf
        files when aligned (``self.histology`` set to 'alf'), otherwise from the
        histology trajectory with the source recorded in ``self.histology``.
        """
        if len(self.collections) == 0:
            return {}, {}, {}
        self.download_spike_sorting(**kwargs)
        channels = alfio.load_object(self.files['channels'], wildcards=self.one.wildcards)
        clusters = alfio.load_object(self.files['clusters'], wildcards=self.one.wildcards)
        spikes = alfio.load_object(self.files['spikes'], wildcards=self.one.wildcards)
        if 'brainLocationIds_ccf_2017' not in channels:
            channels, self.histology = _load_channel_locations_traj(
                self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True)
            channels = channels[self.pname]
        else:
            channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions)
            self.histology = 'alf'
        return spikes, clusters, channels

    @staticmethod
    def merge_clusters(spikes, clusters, channels, cache_dir=None):
        """Merges metrics and channels info into the clusters object.

        Metrics are recomputed with ``quick_unit_metrics`` when absent or when
        their row count does not match the cluster count. When ``cache_dir`` is
        provided, the recomputed metrics and the merged clusters are saved as
        parquet files in that directory. Returns the merged clusters, or None
        when ``spikes`` is empty.
        """
        if spikes == {}:
            return
        nc = clusters['channels'].size
        # recompute metrics if they are not available
        metrics = None
        if 'metrics' in clusters:
            metrics = clusters.pop('metrics')
            if metrics.shape[0] != nc:
                metrics = None
        if metrics is None:
            _logger.debug("recompute clusters metrics")
            metrics = pd.DataFrame(quick_unit_metrics(
                spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc)))
            if isinstance(cache_dir, Path):
                metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt'))
        for k in metrics.keys():
            clusters[k] = metrics[k].to_numpy()

        for k in channels.keys():
            clusters[k] = channels[k][clusters['channels']]
        if cache_dir:
            pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt'))
        return clusters

    @property
    def url(self):
        """Gets the flatiron URL for the session by swapping the local cache
        directory prefix of the session path for the flatiron base URL."""
        return str(self.session_path).replace(str(self.one.alyx.cache_dir), 'https://ibl.flatironinstitute.org')
0 commit comments