@@ -108,7 +108,8 @@ def _channels_alf2bunch(channels, brain_regions=None):
108108 return channels_
109109
110110
111- def _load_spike_sorting (eid , one = None , collection = None , revision = None , return_channels = True , dataset_types = None ):
111+ def _load_spike_sorting (eid , one = None , collection = None , revision = None , return_channels = True , dataset_types = None ,
112+ brain_regions = None ):
112113 """
113114 Generic function to load spike sorting according to one search words
114115 Will try to load one spike sorting for any probe present for the eid matching the collection
@@ -121,6 +122,7 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch
121122 :param collection: collection filter word - accepts wildcard - can be a combination of spike sorter and probe
122123 :param revision: revision to load
123124 :param return_channels: True
125+ :param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
124126 :return:
125127 """
126128 one = one or ONE ()
@@ -140,7 +142,8 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch
140142 clusters [pname ] = one .load_object (eid , collection = probe_collection , obj = 'clusters' ,
141143 attribute = cluster_attributes )
142144
143- channels = _load_channels_locations_from_disk (eid , collection = collection , one = one , revision = revision )
145+ channels = _load_channels_locations_from_disk (eid , collection = collection , one = one , revision = revision ,
146+ brain_regions = brain_regions )
144147
145148 if return_channels :
146149 return spikes , clusters , channels
@@ -179,31 +182,42 @@ def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=
179182 _logger .debug (f"looking for a resolved alignment dataset in { aligned_channel_collections } " )
180183 ac_collection = _get_spike_sorting_collection (aligned_channel_collections , probe )
181184 channels_aligned = one .load_object (eid , 'channels' , collection = ac_collection )
182- # oftentimes the channel map for different spike sorters may be different so interpolate the alignment onto
183- nch = channels [probe ]['localCoordinates' ].shape [0 ]
184- # if there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field
185- # so we reconstruct from the Neuropixel map. This only happens for early pykilosort sorts
186- if 'localCoordinates' in channels_aligned .keys ():
187- aligned_depths = channels_aligned ['localCoordinates' ][:, 1 ]
188- else :
189- assert channels_aligned ['mlapdv' ].shape [0 ] == 384
190- NEUROPIXEL_VERSION = 1
191- from ibllib .ephys .neuropixel import trace_header
192- aligned_depths = trace_header (version = NEUROPIXEL_VERSION )['y' ]
193- depth_aligned , ind_aligned = np .unique (aligned_depths , return_index = True )
194- depths , ind , iinv = np .unique (channels [probe ]['localCoordinates' ][:, 1 ], return_index = True , return_inverse = True )
195- channels [probe ]['mlapdv' ] = np .zeros ((nch , 3 ))
196- for i in np .arange (3 ):
197- channels [probe ]['mlapdv' ][:, i ] = np .interp (
198- depths , depth_aligned , channels_aligned ['mlapdv' ][ind_aligned , i ])[iinv ]
199- # the brain locations have to be interpolated by nearest neighbour
200- fcn_interp = interp1d (depth_aligned , channels_aligned ['brainLocationIds_ccf_2017' ][ind_aligned ], kind = 'nearest' )
201- channels [probe ]['brainLocationIds_ccf_2017' ] = fcn_interp (depths )[iinv ].astype (np .int32 )
185+ channels [probe ] = channel_locations_interpolation (channels_aligned , channels [probe ])
202186 # only have to reformat channels if we were able to load coordinates from disk
203187 channels [probe ] = _channels_alf2bunch (channels [probe ], brain_regions = brain_regions )
204188 return channels
205189
206190
191+ def channel_locations_interpolation (channels_aligned , channels ):
192+ """
192+ oftentimes the channel map for different spike sorters may be different, so we interpolate the alignment onto the current channel map;
193+ if there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field
194+ so we reconstruct from the Neuropixel map. This only happens for early pykilosort sorts
196+ :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys
197+ 'mlapdv' and 'brainLocationIds_ccf_2017' - those are the guide for the interpolation
198+ :param channels: Bunch or dictionary of channels containing at least key 'localCoordinates'
199+ :return: Bunch or dictionary of channels with extra keys 'mlapdv' and 'brainLocationIds_ccf_2017'
200+ """
201+ nch = channels ['localCoordinates' ].shape [0 ]
202+ if 'localCoordinates' in channels_aligned .keys ():
203+ aligned_depths = channels_aligned ['localCoordinates' ][:, 1 ]
204+ else :
205+ assert channels_aligned ['mlapdv' ].shape [0 ] == 384
206+ NEUROPIXEL_VERSION = 1
207+ from ibllib .ephys .neuropixel import trace_header
208+ aligned_depths = trace_header (version = NEUROPIXEL_VERSION )['y' ]
209+ depth_aligned , ind_aligned = np .unique (aligned_depths , return_index = True )
210+ depths , ind , iinv = np .unique (channels ['localCoordinates' ][:, 1 ], return_index = True , return_inverse = True )
211+ channels ['mlapdv' ] = np .zeros ((nch , 3 ))
212+ for i in np .arange (3 ):
213+ channels ['mlapdv' ][:, i ] = np .interp (
214+ depths , depth_aligned , channels_aligned ['mlapdv' ][ind_aligned , i ])[iinv ]
215+ # the brain locations have to be interpolated by nearest neighbour
216+ fcn_interp = interp1d (depth_aligned , channels_aligned ['brainLocationIds_ccf_2017' ][ind_aligned ], kind = 'nearest' )
217+ channels ['brainLocationIds_ccf_2017' ] = fcn_interp (depths )[iinv ].astype (np .int32 )
218+ return channels
219+
220+
207221def _load_channel_locations_traj (eid , probe = None , one = None , revision = None , aligned = False ,
208222 brain_atlas = None ):
209223 print ('from traj' )
@@ -309,17 +323,43 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas
309323 return channels
310324
311325
312- def load_spike_sorting_fast (eid , probe = None , spike_sorter = None , ** kwargs ):
326+ def load_spike_sorting_fast (eid , one = None , probe = None , dataset_types = None , spike_sorter = None , revision = None ,
327+ brain_regions = None , nested = True ):
313328 """
314- Same as load_spike_sorting but with return_channels=True
329+ From an eid, loads spikes, clusters and channels for all probes
330+ The following set of dataset types are loaded:
331+ 'clusters.channels',
332+ 'clusters.depths',
333+ 'clusters.metrics',
334+ 'spikes.clusters',
335+ 'spikes.times',
336+ 'probes.description'
337+ :param eid: experiment UUID or pathlib.Path of the local session
338+ :param one: an instance of OneAlyx
339+ :param probe: name of probe to load in, if not given all probes for session will be loaded
340+ :param dataset_types: additional spikes/clusters objects to add to the standard default list
341+ :param spike_sorter: name of the spike sorting you want to load (None for default)
342+ :param revision: revision to load
343+ :param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
344+ :param nested: if a single probe is required, do not output a dictionary with the probe name as key
345+ :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
315346 """
316347 collection = _collection_filter_from_args (probe , spike_sorter )
317348 _logger .debug (f"load spike sorting with collection filter { collection } " )
318- return _load_spike_sorting (eid , collection = collection , return_channels = True , ** kwargs )
349+ kwargs = dict (eid = eid , one = one , collection = collection , revision = revision , dataset_types = dataset_types ,
350+ brain_regions = brain_regions )
351+ spikes , clusters , channels = _load_spike_sorting (** kwargs , return_channels = True )
352+ clusters = merge_clusters_channels (clusters , channels , keys_to_add_extra = None )
353+ if nested is False :
354+ k = list (spikes .keys ())[0 ]
355+ channels = channels [k ]
356+ clusters = clusters [k ]
357+ spikes = spikes [k ]
358+ return spikes , clusters , channels
319359
320360
321- def load_spike_sorting (eid , one = None , probe = None , dataset_types = None ,
322- spike_sorter = None , revision = None , return_channels = False ):
361+ def load_spike_sorting (eid , one = None , probe = None , dataset_types = None , spike_sorter = None , revision = None ,
362+ brain_regions = None ):
323363 """
324364 From an eid, loads spikes and clusters for all probes
325365 The following set of dataset types are loaded:
@@ -335,12 +375,15 @@ def load_spike_sorting(eid, one=None, probe=None, dataset_types=None,
335375 :param dataset_types: additional spikes/clusters objects to add to the standard default list
336376 :param spike_sorter: name of the spike sorting you want to load (None for default)
337377 :param revision: revision to load
338- :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
378+ :param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
379+ :return: spikes, clusters (dict of bunch, 1 bunch per probe)
339380 """
340381 collection = _collection_filter_from_args (probe , spike_sorter )
341382 _logger .debug (f"load spike sorting with collection filter { collection } " )
342- return _load_spike_sorting (eid = eid , one = one , collection = collection , revision = revision ,
343- return_channels = return_channels , dataset_types = dataset_types )
383+ spikes , clusters = _load_spike_sorting (eid = eid , one = one , collection = collection , revision = revision ,
384+ return_channels = False , dataset_types = dataset_types ,
385+ brain_regions = brain_regions )
386+ return spikes , clusters
344387
345388
346389def load_spike_sorting_with_channel (eid , one = None , probe = None , aligned = False , dataset_types = None ,
0 commit comments