33import gc
44import logging
55import re
6- import os
76from pathlib import Path
87from collections import defaultdict
98
1312import matplotlib .pyplot as plt
1413
1514from one .api import ONE , One
16- from one .alf .path import get_alf_path , full_path_parts , filename_parts
15+ from one .alf .path import get_alf_path , ALFPath
1716from one .alf .exceptions import ALFObjectNotFound , ALFMultipleCollectionsFound
1817from one .alf import cache
1918import one .alf .io as alfio
@@ -410,191 +409,6 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas
410409 return channels
411410
412411
413- def load_spike_sorting_fast (eid , one = None , probe = None , dataset_types = None , spike_sorter = None , revision = None ,
414- brain_regions = None , nested = True , collection = None , return_collection = False ):
415- """
416- From an eid, loads spikes and clusters for all probes
417- The following set of dataset types are loaded:
418- 'clusters.channels',
419- 'clusters.depths',
420- 'clusters.metrics',
421- 'spikes.clusters',
422- 'spikes.times',
423- 'probes.description'
424- :param eid: experiment UUID or pathlib.Path of the local session
425- :param one: an instance of OneAlyx
426- :param probe: name of probe to load in, if not given all probes for session will be loaded
427- :param dataset_types: additional spikes/clusters objects to add to the standard default list
428- :param spike_sorter: name of the spike sorting you want to load (None for default)
429- :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00"
430- :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
431- :param nested: if a single probe is required, do not output a dictionary with the probe name as key
432- :param return_collection: (False) if True, will return the collection used to load
433- :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe)
434- """
435- _logger .warning ('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions.'
436- 'Use brainbox.io.one.SpikeSortingLoader instead' )
437- if collection is None :
438- collection = _collection_filter_from_args (probe , spike_sorter )
439- _logger .debug (f"load spike sorting with collection filter { collection } " )
440- kwargs = dict (eid = eid , one = one , collection = collection , revision = revision , dataset_types = dataset_types ,
441- brain_regions = brain_regions )
442- spikes , clusters , channels = _load_spike_sorting (** kwargs , return_channels = True )
443- clusters = merge_clusters_channels (clusters , channels , keys_to_add_extra = None )
444- if nested is False and len (spikes .keys ()) == 1 :
445- k = list (spikes .keys ())[0 ]
446- channels = channels [k ]
447- clusters = clusters [k ]
448- spikes = spikes [k ]
449- if return_collection :
450- return spikes , clusters , channels , collection
451- else :
452- return spikes , clusters , channels
453-
454-
455- def load_spike_sorting (eid , one = None , probe = None , dataset_types = None , spike_sorter = None , revision = None ,
456- brain_regions = None , return_collection = False ):
457- """
458- From an eid, loads spikes and clusters for all probes
459- The following set of dataset types are loaded:
460- 'clusters.channels',
461- 'clusters.depths',
462- 'clusters.metrics',
463- 'spikes.clusters',
464- 'spikes.times',
465- 'probes.description'
466- :param eid: experiment UUID or pathlib.Path of the local session
467- :param one: an instance of OneAlyx
468- :param probe: name of probe to load in, if not given all probes for session will be loaded
469- :param dataset_types: additional spikes/clusters objects to add to the standard default list
470- :param spike_sorter: name of the spike sorting you want to load (None for default)
471- :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided
472- :param return_collection:(bool - False) if True, returns the collection for loading the data
473- :return: spikes, clusters (dict of bunch, 1 bunch per probe)
474- """
475- _logger .warning ('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.'
476- 'Use brainbox.io.one.SpikeSortingLoader instead' )
477- collection = _collection_filter_from_args (probe , spike_sorter )
478- _logger .debug (f"load spike sorting with collection filter { collection } " )
479- spikes , clusters = _load_spike_sorting (eid = eid , one = one , collection = collection , revision = revision ,
480- return_channels = False , dataset_types = dataset_types ,
481- brain_regions = brain_regions )
482- if return_collection :
483- return spikes , clusters , collection
484- else :
485- return spikes , clusters
486-
487-
488- def load_spike_sorting_with_channel (eid , one = None , probe = None , aligned = False , dataset_types = None ,
489- spike_sorter = None , brain_atlas = None , nested = True , return_collection = False ):
490- """
491- For a given eid, get spikes, clusters and channels information, and merges clusters
492- and channels information before returning all three variables.
493-
494- Parameters
495- ----------
496- eid : [str, UUID, Path, dict]
497- Experiment session identifier; may be a UUID, URL, experiment reference string
498- details dict or Path
499- one : one.api.OneAlyx
500- An instance of ONE (shouldn't be in 'local' mode)
501- probe : [str, list of str]
502- The probe label(s), e.g. 'probe01'
503- aligned : bool
504- Whether to get the latest user aligned channel when not resolved or use histology track
505- dataset_types : list of str
506- Optional additional spikes/clusters objects to add to the standard default list
507- spike_sorter : str
508- Name of the spike sorting you want to load (None for default which is pykilosort if it's
509- available otherwise the default MATLAB kilosort)
510- brain_atlas : iblatlas.atlas.BrainAtlas
511- Brain atlas object (default: Allen atlas)
512- return_collection: bool
513- Returns an extra argument with the collection chosen
514-
515- Returns
516- -------
517- spikes : dict of one.alf.io.AlfBunch
518- A dict with probe labels as keys, contains bunch(es) of spike data for the provided
519- session and spike sorter, with keys ('clusters', 'times')
520- clusters : dict of one.alf.io.AlfBunch
521- A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
522- ('channels', 'depths', 'metrics')
523- channels : dict of one.alf.io.AlfBunch
524- A dict with probe labels as keys, contains channel locations with keys ('acronym',
525- 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized.
526- """
527- # --- Get spikes and clusters data
528- _logger .warning ('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.'
529- 'Use brainbox.io.one.SpikeSortingLoader instead' )
530- one = one or ONE ()
531- brain_atlas = brain_atlas or AllenAtlas ()
532- spikes , clusters , collection = load_spike_sorting (
533- eid , one = one , probe = probe , dataset_types = dataset_types , spike_sorter = spike_sorter , return_collection = True )
534- # -- Get brain regions and assign to clusters
535- channels = load_channel_locations (eid , one = one , probe = probe , aligned = aligned ,
536- brain_atlas = brain_atlas )
537- clusters = merge_clusters_channels (clusters , channels , keys_to_add_extra = None )
538- if nested is False and len (spikes .keys ()) == 1 :
539- k = list (spikes .keys ())[0 ]
540- channels = channels [k ]
541- clusters = clusters [k ]
542- spikes = spikes [k ]
543- if return_collection :
544- return spikes , clusters , channels , collection
545- else :
546- return spikes , clusters , channels
547-
548-
549- def load_ephys_session (eid , one = None ):
550- """
551- From an eid, hits the Alyx database and downloads a standard default set of dataset types
552- From a local session Path (pathlib.Path), loads a standard default set of dataset types
553- to perform analysis:
554- 'clusters.channels',
555- 'clusters.depths',
556- 'clusters.metrics',
557- 'spikes.clusters',
558- 'spikes.times',
559- 'probes.description'
560-
561- Parameters
562- ----------
563- eid : [str, UUID, Path, dict]
564- Experiment session identifier; may be a UUID, URL, experiment reference string
565- details dict or Path
566- one : oneibl.one.OneAlyx, optional
567- ONE object to use for loading. Will generate internal one if not used, by default None
568-
569- Returns
570- -------
571- spikes : dict of one.alf.io.AlfBunch
572- A dict with probe labels as keys, contains bunch(es) of spike data for the provided
573- session and spike sorter, with keys ('clusters', 'times')
574- clusters : dict of one.alf.io.AlfBunch
575- A dict with probe labels as keys, contains bunch(es) of cluster data, with keys
576- ('channels', 'depths', 'metrics')
577- trials : one.alf.io.AlfBunch of numpy.ndarray
578- The session trials data
579- """
580- assert one
581- spikes , clusters = load_spike_sorting (eid , one = one )
582- trials = one .load_object (eid , 'trials' )
583- return spikes , clusters , trials
584-
585-
586- def _remove_old_clusters (session_path , probe ):
587- # gets clusters and spikes from a local session folder
588- probe_path = session_path .joinpath ('alf' , probe )
589-
590- # look for clusters.metrics.csv file, if it exists delete as we now have .pqt file instead
591- cluster_file = probe_path .joinpath ('clusters.metrics.csv' )
592-
593- if cluster_file .exists ():
594- os .remove (cluster_file )
595- _logger .info ('Deleting old clusters.metrics.csv file' )
596-
597-
598412def merge_clusters_channels (dic_clus , channels , keys_to_add_extra = None ):
599413 """
600414 Takes (default and any extra) values in given keys from channels and assign them to clusters.
@@ -785,7 +599,7 @@ class SpikeSortingLoader:
785599 This class can be instantiated in several manners
786600 - With Alyx database probe id:
787601 SpikeSortingLoader(pid=pid, one=one)
788- - With Alyx database eic and probe name:
602+ - With Alyx database eid and probe name:
789603 SpikeSortingLoader(eid=eid, pname='probe00', one=one)
790604 - From a local session and probe name:
791605 SpikeSortingLoader(session_path=session_path, pname='probe00')
@@ -796,7 +610,7 @@ class SpikeSortingLoader:
796610 pid : str = None
797611 eid : str = ''
798612 pname : str = ''
799- session_path : Path = ''
613+ session_path : ALFPath = ''
800614 # the following properties are the outcome of the post init function
801615 collections : list = None
802616 datasets : list = None # list of all datasets belonging to the session
@@ -825,6 +639,7 @@ def __post_init__(self):
825639 self .session_path = self .one .eid2path (self .eid )
826640 # fully local providing a session path
827641 else :
642+ self .session_path = ALFPath (self .session_path ) # Ensure session_path is an ALFPath object
828643 if self .one :
829644 self .eid = self .one .to_eid (self .session_path )
830645 else :
@@ -1048,11 +863,9 @@ def filter_files_by_namespace(all_files, namespace):
1048863 namespace_files = defaultdict (dict )
1049864 available_namespaces = []
1050865 for file in all_files :
1051- fparts = filename_parts (file .name , as_dict = True )
1052- fname = f"{ fparts ['object' ]} .{ fparts ['attribute' ]} "
1053- nspace = fparts ['namespace' ]
866+ nspace = file .namespace or None
1054867 available_namespaces .append (nspace )
1055- namespace_files [fname ][nspace ] = file
868+ namespace_files [f" { file . object } . { file . attribute } " ][nspace ] = file
1056869
1057870 if namespace not in set (available_namespaces ):
1058871 _logger .info (f'Could not find manual curation results for { namespace } , returning default'
@@ -1124,7 +937,7 @@ def _assert_version_consistency(self):
1124937 assert fn .relative_to (self .session_path ).parts [2 ] == self .spike_sorter , \
1125938 f"You required strict version { self .spike_sorter } , { fn } does not match"
1126939 if self .revision :
1127- assert full_path_parts ( fn )[ 5 ] == self .revision , \
940+ assert fn . revision == self .revision , \
1128941 f"You required strict revision { self .revision } , { fn } does not match"
1129942
1130943 @staticmethod
@@ -1171,7 +984,7 @@ def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=F
1171984
1172985 @property
1173986 def url (self ):
1174- """Gets flatiron URL for the session"""
987+ """Gets flatiron URL for the session. """
1175988 webclient = getattr (self .one , '_web_client' , None )
1176989 return webclient .rel_path2url (get_alf_path (self .session_path )) if webclient else None
1177990
@@ -1318,7 +1131,7 @@ def plot_rawdata_snippet(self, sr, spikes, clusters, t0,
13181131 return fig , axs
13191132
13201133
1321- @dataclass
1134+ @dataclass ( kw_only = True )
13221135class SessionLoader :
13231136 """
13241137 Object to load session data for a give session in the recommended way.
@@ -1386,7 +1199,7 @@ class SessionLoader:
13861199 >>> sess_loader.load_wheel(sampling_rate=100)
13871200 """
13881201 one : One = None
1389- session_path : Path = ''
1202+ session_path : ALFPath = ''
13901203 eid : str = ''
13911204 revision : str = ''
13921205 data_info : pd .DataFrame = field (default_factory = pd .DataFrame , repr = False )
@@ -1407,7 +1220,7 @@ def __post_init__(self):
14071220 # If session path is given, takes precedence over eid
14081221 if self .session_path is not None and self .session_path != '' :
14091222 self .eid = self .one .to_eid (self .session_path )
1410- self .session_path = Path (self .session_path )
1223+ self .session_path = ALFPath (self .session_path )
14111224 # Providing no session path, try to infer from eid
14121225 else :
14131226 if self .eid is not None and self .eid != '' :
@@ -1493,7 +1306,7 @@ def _find_behaviour_collection(self, obj):
14931306 if len (dsets ) == 0 :
14941307 return 'alf'
14951308 else :
1496- collections = [full_path_parts ( self . session_path . joinpath ( d ), as_dict = True )[ ' collection' ] for d in dsets ]
1309+ collections = [x . collection for x in map ( self . session_path . joinpath , dsets ) ]
14971310 if len (set (collections )) == 1 :
14981311 return collections [0 ]
14991312 else :
0 commit comments