Skip to content

Commit 8fe4acf

Browse files
committed
brainbox io: spike sorting loader can return collection
1 parent 99fbc0f commit 8fe4acf

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

brainbox/io/one.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,6 @@ def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, align
276276
collections = one.list_collections(eid, filename='channels*', collection=collection,
277277
revision=revision)
278278
probe_collection = _get_spike_sorting_collection(collections, probe)
279-
print(probe_collection)
280279
chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection)
281280
depths = chn_coords[:, 1]
282281

@@ -388,7 +387,7 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas
388387

389388

390389
def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
391-
brain_regions=None, nested=True, collection=None):
390+
brain_regions=None, nested=True, collection=None, return_collection=False):
392391
"""
393392
From an eid, loads spikes and clusters for all probes
394393
The following set of dataset types are loaded:
@@ -407,7 +406,8 @@ def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike
407406
:param return_channels: (bool) defaults to False otherwise tries and load channels from disk
408407
:param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
409408
:param nested: if a single probe is required, do not output a dictionary with the probe name as key
410-
:return: spikes, clusters (dict of bunch, 1 bunch per probe)
409+
:param return_collection: (False) if True, will return the collection used to load
410+
:return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
411411
"""
412412
if collection is None:
413413
collection = _collection_filter_from_args(probe, spike_sorter)
@@ -421,7 +421,10 @@ def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike
421421
channels = channels[k]
422422
clusters = clusters[k]
423423
spikes = spikes[k]
424-
return spikes, clusters, channels
424+
if return_collection:
425+
return spikes, clusters, channels, collection
426+
else:
427+
return spikes, clusters, channels
425428

426429

427430
def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,

0 commit comments

Comments
 (0)