Commit 22fa715

Merge pull request #1002 from int-brain-lab/kwargsOnly
SessionLoader now takes kwargs only; removed old spike sorting loaders; ensure dataset not protected before removing
2 parents 3c00184 + adacfdf

File tree

8 files changed (+36, -318 lines)


brainbox/examples/raster_depths.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

brainbox/io/one.py

Lines changed: 12 additions & 199 deletions
@@ -3,7 +3,6 @@
 import gc
 import logging
 import re
-import os
 from pathlib import Path
 from collections import defaultdict
 
@@ -13,7 +12,7 @@
 import matplotlib.pyplot as plt
 
 from one.api import ONE, One
-from one.alf.path import get_alf_path, full_path_parts, filename_parts
+from one.alf.path import get_alf_path, ALFPath
 from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound
 from one.alf import cache
 import one.alf.io as alfio
@@ -410,191 +409,6 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas
     return channels
 
 
-def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
-                            brain_regions=None, nested=True, collection=None, return_collection=False):
-    """
-    From an eid, loads spikes and clusters for all probes
-    The following set of dataset types are loaded:
-        'clusters.channels',
-        'clusters.depths',
-        'clusters.metrics',
-        'spikes.clusters',
-        'spikes.times',
-        'probes.description'
-    :param eid: experiment UUID or pathlib.Path of the local session
-    :param one: an instance of OneAlyx
-    :param probe: name of probe to load in, if not given all probes for session will be loaded
-    :param dataset_types: additional spikes/clusters objects to add to the standard default list
-    :param spike_sorter: name of the spike sorting you want to load (None for default)
-    :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00"
-    :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
-    :param nested: if a single probe is required, do not output a dictionary with the probe name as key
-    :param return_collection: (False) if True, will return the collection used to load
-    :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
-    """
-    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions.'
-                    'Use brainbox.io.one.SpikeSortingLoader instead')
-    if collection is None:
-        collection = _collection_filter_from_args(probe, spike_sorter)
-    _logger.debug(f"load spike sorting with collection filter {collection}")
-    kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types,
-                  brain_regions=brain_regions)
-    spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True)
-    clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
-    if nested is False and len(spikes.keys()) == 1:
-        k = list(spikes.keys())[0]
-        channels = channels[k]
-        clusters = clusters[k]
-        spikes = spikes[k]
-    if return_collection:
-        return spikes, clusters, channels, collection
-    else:
-        return spikes, clusters, channels
-
-
-def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None,
-                       brain_regions=None, return_collection=False):
-    """
-    From an eid, loads spikes and clusters for all probes
-    The following set of dataset types are loaded:
-        'clusters.channels',
-        'clusters.depths',
-        'clusters.metrics',
-        'spikes.clusters',
-        'spikes.times',
-        'probes.description'
-    :param eid: experiment UUID or pathlib.Path of the local session
-    :param one: an instance of OneAlyx
-    :param probe: name of probe to load in, if not given all probes for session will be loaded
-    :param dataset_types: additional spikes/clusters objects to add to the standard default list
-    :param spike_sorter: name of the spike sorting you want to load (None for default)
-    :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
-    :param return_collection:(bool - False) if True, returns the collection for loading the data
-    :return: spikes, clusters (dict of bunch, 1 bunch per probe)
-    """
-    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.'
-                    'Use brainbox.io.one.SpikeSortingLoader instead')
-    collection = _collection_filter_from_args(probe, spike_sorter)
-    _logger.debug(f"load spike sorting with collection filter {collection}")
-    spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision,
-                                           return_channels=False, dataset_types=dataset_types,
-                                           brain_regions=brain_regions)
-    if return_collection:
-        return spikes, clusters, collection
-    else:
-        return spikes, clusters
-
-
-def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None,
-                                    spike_sorter=None, brain_atlas=None, nested=True, return_collection=False):
-    """
-    For a given eid, get spikes, clusters and channels information, and merges clusters
-    and channels information before returning all three variables.
-
-    Parameters
-    ----------
-    eid : [str, UUID, Path, dict]
-        Experiment session identifier; may be a UUID, URL, experiment reference string
-        details dict or Path
-    one : one.api.OneAlyx
-        An instance of ONE (shouldn't be in 'local' mode)
-    probe : [str, list of str]
-        The probe label(s), e.g. 'probe01'
-    aligned : bool
-        Whether to get the latest user aligned channel when not resolved or use histology track
-    dataset_types : list of str
-        Optional additional spikes/clusters objects to add to the standard default list
-    spike_sorter : str
-        Name of the spike sorting you want to load (None for default which is pykilosort if it's
-        available otherwise the default MATLAB kilosort)
-    brain_atlas : iblatlas.atlas.BrainAtlas
-        Brain atlas object (default: Allen atlas)
-    return_collection: bool
-        Returns an extra argument with the collection chosen
-
-    Returns
-    -------
-    spikes : dict of one.alf.io.AlfBunch
-        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
-        session and spike sorter, with keys ('clusters', 'times')
-    clusters : dict of one.alf.io.AlfBunch
-        A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
-        ('channels', 'depths', 'metrics')
-    channels : dict of one.alf.io.AlfBunch
-        A dict with probe labels as keys, contains channel locations with keys ('acronym',
-        'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
-    """
-    # --- Get spikes and clusters data
-    _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.'
-                    'Use brainbox.io.one.SpikeSortingLoader instead')
-    one = one or ONE()
-    brain_atlas = brain_atlas or AllenAtlas()
-    spikes, clusters, collection = load_spike_sorting(
-        eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True)
-    # -- Get brain regions and assign to clusters
-    channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned,
-                                      brain_atlas=brain_atlas)
-    clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None)
-    if nested is False and len(spikes.keys()) == 1:
-        k = list(spikes.keys())[0]
-        channels = channels[k]
-        clusters = clusters[k]
-        spikes = spikes[k]
-    if return_collection:
-        return spikes, clusters, channels, collection
-    else:
-        return spikes, clusters, channels
-
-
-def load_ephys_session(eid, one=None):
-    """
-    From an eid, hits the Alyx database and downloads a standard default set of dataset types
-    From a local session Path (pathlib.Path), loads a standard default set of dataset types
-    to perform analysis:
-        'clusters.channels',
-        'clusters.depths',
-        'clusters.metrics',
-        'spikes.clusters',
-        'spikes.times',
-        'probes.description'
-
-    Parameters
-    ----------
-    eid : [str, UUID, Path, dict]
-        Experiment session identifier; may be a UUID, URL, experiment reference string
-        details dict or Path
-    one : oneibl.one.OneAlyx, optional
-        ONE object to use for loading. Will generate internal one if not used, by default None
-
-    Returns
-    -------
-    spikes : dict of one.alf.io.AlfBunch
-        A dict with probe labels as keys, contains bunch(es) of spike data for the provided
-        session and spike sorter, with keys ('clusters', 'times')
-    clusters : dict of one.alf.io.AlfBunch
-        A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
-        ('channels', 'depths', 'metrics')
-    trials : one.alf.io.AlfBunch of numpy.ndarray
-        The session trials data
-    """
-    assert one
-    spikes, clusters = load_spike_sorting(eid, one=one)
-    trials = one.load_object(eid, 'trials')
-    return spikes, clusters, trials
-
-
-def _remove_old_clusters(session_path, probe):
-    # gets clusters and spikes from a local session folder
-    probe_path = session_path.joinpath('alf', probe)
-
-    # look for clusters.metrics.csv file, if it exists delete as we now have .pqt file instead
-    cluster_file = probe_path.joinpath('clusters.metrics.csv')
-
-    if cluster_file.exists():
-        os.remove(cluster_file)
-        _logger.info('Deleting old clusters.metrics.csv file')
-
-
 def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None):
     """
     Takes (default and any extra) values in given keys from channels and assign them to clusters.
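
For reference, the loaders deleted above already pointed users at `SpikeSortingLoader` in their deprecation warnings. A minimal migration sketch, assuming `SpikeSortingLoader.load_spike_sorting` and the `merge_clusters` static method keep the signatures currently documented in brainbox:

```python
# Migration sketch from the removed module-level loaders (assumed API,
# based on the deprecation warnings above pointing at SpikeSortingLoader).
from one.api import ONE
from brainbox.io.one import SpikeSortingLoader

one = ONE()
eid = 'your-experiment-uuid'  # placeholder

# Before (removed in this commit):
#     spikes, clusters = load_spike_sorting(eid, one=one, probe='probe00')

# After: one loader instance per probe, keyword arguments only
ssl = SpikeSortingLoader(eid=eid, pname='probe00', one=one)
spikes, clusters, channels = ssl.load_spike_sorting()
# The old loaders merged channel info into clusters internally;
# the equivalent step is now explicit:
clusters = SpikeSortingLoader.merge_clusters(spikes, clusters, channels)
```
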
@@ -785,7 +599,7 @@ class SpikeSortingLoader:
     This class can be instantiated in several manners
     - With Alyx database probe id:
         SpikeSortingLoader(pid=pid, one=one)
-    - With Alyx database eic and probe name:
+    - With Alyx database eid and probe name:
         SpikeSortingLoader(eid=eid, pname='probe00', one=one)
     - From a local session and probe name:
         SpikeSortingLoader(session_path=session_path, pname='probe00')
@@ -796,7 +610,7 @@
     pid: str = None
     eid: str = ''
     pname: str = ''
-    session_path: Path = ''
+    session_path: ALFPath = ''
     # the following properties are the outcome of the post init function
     collections: list = None
     datasets: list = None  # list of all datasets belonging to the session
@@ -825,6 +639,7 @@ def __post_init__(self):
             self.session_path = self.one.eid2path(self.eid)
         # fully local providing a session path
         else:
+            self.session_path = ALFPath(self.session_path)  # Ensure session_path is an ALFPath object
             if self.one:
                 self.eid = self.one.to_eid(self.session_path)
             else:
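
With the added line above, a plain string or `pathlib.Path` passed as `session_path` is normalised to an `ALFPath`, so the fully local instantiation shown in the class docstring keeps yielding ALF-aware paths. A short sketch (the session folder below is made up):

```python
# Local instantiation without a ONE instance; the session folder is hypothetical.
from one.alf.path import ALFPath
from brainbox.io.one import SpikeSortingLoader

ssl = SpikeSortingLoader(session_path='/data/Subjects/KS022/2019-12-10/001',
                         pname='probe00')
# __post_init__ coerces the string to an ALFPath
assert isinstance(ssl.session_path, ALFPath)
```
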
@@ -1048,11 +863,9 @@ def filter_files_by_namespace(all_files, namespace):
        namespace_files = defaultdict(dict)
        available_namespaces = []
        for file in all_files:
-            fparts = filename_parts(file.name, as_dict=True)
-            fname = f"{fparts['object']}.{fparts['attribute']}"
-            nspace = fparts['namespace']
+            nspace = file.namespace or None
            available_namespaces.append(nspace)
-            namespace_files[fname][nspace] = file
+            namespace_files[f"{file.object}.{file.attribute}"][nspace] = file
 
        if namespace not in set(available_namespaces):
            _logger.info(f'Could not find manual curation results for {namespace}, returning default'
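
The rewrite above leans on `ALFPath` exposing the parsed ALF parts (`namespace`, `object`, `attribute`, and elsewhere in this diff `collection` and `revision`) as attributes, replacing the `filename_parts`/`full_path_parts` calls. A rough illustration; the path and expected values are illustrative, not taken from the repository:

```python
# Illustration of the ALFPath attribute access this diff switches to;
# the file path and the expected values in comments are assumptions.
from one.alf.path import ALFPath

f = ALFPath('/data/Subjects/KS022/2019-12-10/001/alf/probe00/spikes.times.npy')
f.object      # expected 'spikes'
f.attribute   # expected 'times'
f.namespace   # expected '' here, hence the `file.namespace or None` above
f.collection  # expected 'alf/probe00'
f.revision    # expected '' unless the path contains a '#<revision>#' folder
```
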
@@ -1124,7 +937,7 @@ def _assert_version_consistency(self):
                assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \
                    f"You required strict version {self.spike_sorter}, {fn} does not match"
            if self.revision:
-                assert full_path_parts(fn)[5] == self.revision, \
+                assert fn.revision == self.revision, \
                    f"You required strict revision {self.revision}, {fn} does not match"
 
    @staticmethod
@@ -1171,7 +984,7 @@ def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=F
 
     @property
     def url(self):
-        """Gets flatiron URL for the session"""
+        """Gets flatiron URL for the session."""
         webclient = getattr(self.one, '_web_client', None)
         return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None
 
@@ -1318,7 +1131,7 @@ def plot_rawdata_snippet(self, sr, spikes, clusters, t0,
         return fig, axs
 
 
-@dataclass
+@dataclass(kw_only=True)
 class SessionLoader:
     """
     Object to load session data for a give session in the recommended way.
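
`kw_only=True` (available since Python 3.10) is what the commit title refers to: every field of the generated `__init__` becomes keyword-only, so callers can no longer bind `one` or `session_path` positionally. A self-contained sketch of the behaviour change, using a stand-in class rather than the real `SessionLoader`:

```python
# Stand-in demo of dataclass(kw_only=True); Demo mirrors the first
# SessionLoader fields but is not the real class.
from dataclasses import dataclass

@dataclass(kw_only=True)
class Demo:
    one: object = None
    session_path: str = ''
    eid: str = ''

Demo(eid='abc')  # fine: explicit keywords
try:
    Demo(None, '/some/path')  # positional arguments now raise
except TypeError as e:
    print(e)  # e.g. __init__() takes 1 positional argument but 3 were given
```
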
@@ -1386,7 +1199,7 @@ class SessionLoader:
     >>> sess_loader.load_wheel(sampling_rate=100)
     """
     one: One = None
-    session_path: Path = ''
+    session_path: ALFPath = ''
     eid: str = ''
     revision: str = ''
     data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
@@ -1407,7 +1220,7 @@ def __post_init__(self):
         # If session path is given, takes precedence over eid
         if self.session_path is not None and self.session_path != '':
             self.eid = self.one.to_eid(self.session_path)
-            self.session_path = Path(self.session_path)
+            self.session_path = ALFPath(self.session_path)
         # Providing no session path, try to infer from eid
         else:
             if self.eid is not None and self.eid != '':
@@ -1493,7 +1306,7 @@ def _find_behaviour_collection(self, obj):
         if len(dsets) == 0:
             return 'alf'
         else:
-            collections = [full_path_parts(self.session_path.joinpath(d), as_dict=True)['collection'] for d in dsets]
+            collections = [x.collection for x in map(self.session_path.joinpath, dsets)]
             if len(set(collections)) == 1:
                 return collections[0]
             else:
