 
 from one.api import ONE, One
 import one.alf.io as alfio
+from one.alf.files import get_alf_path
 from one.alf import cache
+from neuropixel import SITES_COORDINATES, TIP_SIZE_UM, trace_header
+import spikeglx
 
 from iblutil.util import Bunch
-from ibllib.io import spikeglx
 from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
-from ibllib.ephys.neuropixel import SITES_COORDINATES, TIP_SIZE_UM, trace_header
 from ibllib.atlas import atlas, AllenAtlas
 from ibllib.pipes import histology
 from ibllib.pipes.ephys_alignment import EphysAlignment
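
Note: spikeglx and neuropixel are now imported as standalone top-level packages instead of from ibllib.io and ibllib.ephys. Downstream code would migrate its imports the same way; a minimal sketch (assuming the standalone packages are installed alongside ibllib):

    # old imports, removed above:
    #   from ibllib.io import spikeglx
    #   from ibllib.ephys.neuropixel import SITES_COORDINATES, TIP_SIZE_UM, trace_header
    # new standalone packages:
    import spikeglx
    from neuropixel import TIP_SIZE_UM, trace_header

    th = trace_header(version=1)  # channel geometry for a Neuropixels 1.0 probe
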
@@ -123,7 +124,7 @@ def _channels_alf2bunch(channels, brain_regions=None):
 
 
 def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None,
-                        brain_regions=None, return_collection=False):
+                        brain_regions=None):
     """
     Generic function to load spike sorting data using ONE.
 
@@ -168,7 +169,7 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch
     collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision)
     if len(collections) == 0:
         _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}")
-    pnames = list(set([c.split('/')[1] for c in collections]))
+    pnames = list(set(c.split('/')[1] for c in collections))
     spikes, clusters, channels = ({} for _ in range(3))
 
     spike_attributes, cluster_attributes = _get_attributes(dataset_types)
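
The pnames line above derives probe names from the second path segment of each spike-sorting collection; the list-comprehension-to-generator change is purely stylistic. For illustration, with made-up collection names:

    collections = ['alf/probe00/pykilosort', 'alf/probe00', 'alf/probe01']
    pnames = list(set(c.split('/')[1] for c in collections))
    # -> ['probe00', 'probe01'] (unordered, duplicates dropped via the set)
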
@@ -246,7 +247,7 @@ def channel_locations_interpolation(channels_aligned, channels=None, brain_regio
     if channels is None:
         channels = {'localCoordinates': np.c_[h['x'], h['y']]}
     nch = channels['localCoordinates'].shape[0]
-    if set(['x', 'y', 'z']).issubset(set(channels_aligned.keys())):
+    if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())):
         channels_aligned = _channels_bunch2alf(channels_aligned)
     if 'localCoordinates' in channels_aligned.keys():
         aligned_depths = channels_aligned['localCoordinates'][:, 1]
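
The rewritten membership test (a set literal instead of set([...])) behaves identically: channels_aligned may arrive either as an x/y/z bunch or already in ALF form with 'localCoordinates'. An illustrative sketch with made-up arrays:

    import numpy as np
    # bunch-style input has x/y/z, so it is routed through _channels_bunch2alf first
    channels_aligned = {'x': np.zeros(384), 'y': np.zeros(384), 'z': np.zeros(384),
                        'axial_um': 20. * np.arange(384), 'lateral_um': np.zeros(384)}
    needs_conversion = {'x', 'y', 'z'}.issubset(set(channels_aligned.keys()))  # True
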
@@ -350,7 +351,7 @@ def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, align
     return channels
 
 
-def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None, return_source=False):
+def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None):
     """
     Load the brain locations of each channel for a given session/probe
 
@@ -367,8 +368,6 @@ def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas
         Whether to get the latest user aligned channel when not resolved or use histology track
     brain_atlas : ibllib.atlas.BrainAtlas
         Brain atlas object (default: Allen atlas)
-    return_source: bool
-        if True returns the source of the channel lcoations (default False)
     Returns
     -------
     dict of one.alf.io.AlfBunch
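
With return_source gone, callers now get only the per-probe channels dict. A hedged usage sketch (the eid is a placeholder):

    from one.api import ONE
    one = ONE()
    channels = load_channel_locations('paste-eid-here', probe='probe00', one=one)
    print(channels['probe00'].acronym)  # brain region acronym for each channel
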
@@ -412,7 +411,6 @@ def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike
     :param dataset_types: additional spikes/clusters objects to add to the standard default list
     :param spike_sorter: name of the spike sorting you want to load (None for default)
     :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00"
-    :param return_channels: (bool) defaults to False otherwise tries and load channels from disk
     :param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
     :param nested: if a single probe is required, do not output a dictionary with the probe name as key
     :param return_collection: (False) if True, will return the collection used to load
@@ -454,7 +452,6 @@ def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sort
     :param probe: name of probe to load in, if not given all probes for session will be loaded
     :param dataset_types: additional spikes/clusters objects to add to the standard default list
     :param spike_sorter: name of the spike sorting you want to load (None for default)
-    :param return_channels: (bool) defaults to False otherwise tries and load channels from disk
     :param brain_regions: ibllib.atlas.regions.BrainRegions object - will label acronyms if provided
     :param return_collection: (bool - False) if True, returns the collection for loading the data
     :return: spikes, clusters (dict of bunch, 1 bunch per probe)
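
A minimal usage sketch matching the updated signature (the eid is a placeholder; keyword names as documented above):

    from one.api import ONE
    one = ONE()
    spikes, clusters = load_spike_sorting('paste-eid-here', one=one, probe='probe00',
                                          spike_sorter='pykilosort')
    print(spikes['probe00'].times.shape)  # spike times for the requested probe
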
@@ -677,7 +674,7 @@ def load_wheel_reaction_times(eid, one=None):
     eid : [str, UUID, Path, dict]
         Experiment session identifier; may be a UUID, URL, experiment reference string
         details dict or Path
-    one : oneibl.one.OneAlyx, optional
+    one : one.api.OneAlyx, optional
         one object to use for loading. Will generate internal one if not used, by default None
 
     Returns
@@ -703,8 +700,9 @@ def load_wheel_reaction_times(eid, one=None):
     return firstMove_times - trials['goCue_times']
 
 
-def load_trials_df(eid, one=None, maxlen=None, t_before=0., t_after=0., ret_wheel=False,
-                   ret_abswheel=False, wheel_binsize=0.02, addtl_types=[]):
+def load_trials_df(eid, one=None, maxlen=None, t_before=0., t_after=0.2, ret_wheel=False,
+                   ret_abswheel=False, wheel_binsize=0.02, addtl_types=[],
+                   align_event='stimOn_times', keeptrials=None):
     """
     Generate a pandas dataframe of per-trial timing information about a given session.
     Each row in the frame will correspond to a single trial, with timing values indicating timing
@@ -720,7 +718,7 @@ def load_trials_df(eid, one=None, maxlen=None, t_before=0., t_after=0., ret_whee
     eid : [str, UUID, Path, dict]
         Experiment session identifier; may be a UUID, URL, experiment reference string
         details dict or Path
-    one : oneibl.one.OneAlyx, optional
+    one : one.api.OneAlyx, optional
         one object to use for loading. Will generate internal one if not used, by default None
     maxlen : float, optional
         Maximum trial length for inclusion in df. Trials where feedback - response is longer
@@ -779,18 +777,19 @@ def remap_trialp(probs):
     endtimes = trials.feedback_times
     tmp = {key: value for key, value in trials.items() if key in trialstypes}
 
-    if maxlen is not None:
-        with np.errstate(invalid='ignore'):
-            keeptrials = (endtimes - starttimes) <= maxlen
-    else:
-        keeptrials = range(len(starttimes))
+    if keeptrials is None:
+        if maxlen is not None:
+            with np.errstate(invalid='ignore'):
+                keeptrials = (endtimes - starttimes) <= maxlen
+        else:
+            keeptrials = range(len(starttimes))
     trialdata = {x: tmp[x][keeptrials] for x in trialstypes}
     trialdata['probabilityLeft'] = remap_trialp(trialdata['probabilityLeft'])
     trialsdf = pd.DataFrame(trialdata)
     if maxlen is not None:
         trialsdf.set_index(np.nonzero(keeptrials)[0], inplace=True)
-    trialsdf['trial_start'] = trialsdf['stimOn_times'] - t_before
-    trialsdf['trial_end'] = trialsdf['feedback_times'] + t_after
+    trialsdf['trial_start'] = trialsdf[align_event] - t_before
+    trialsdf['trial_end'] = trialsdf[align_event] + t_after
     tdiffs = trialsdf['trial_end'] - np.roll(trialsdf['trial_start'], -1)
     if np.any(tdiffs[:-1] > 0):
         logging.warning(f'{sum(tdiffs[:-1] > 0)} trials overlapping due to t_before and t_after '
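
The two new parameters compose as follows: align_event picks the trials column used for the trial_start/trial_end window, and a non-None keeptrials bypasses the maxlen filter entirely. A hedged usage sketch (the eid is a placeholder):

    import numpy as np
    from one.api import ONE
    one = ONE()
    eid = 'paste-eid-here'
    # window trials around feedback instead of stimulus onset, keeping only
    # correct trials selected by a boolean mask computed beforehand
    trials = one.load_object(eid, 'trials')
    correct = trials['feedbackType'] == 1
    trialsdf = load_trials_df(eid, one=one, t_before=0.5, t_after=1.0,
                              align_event='feedback_times', keeptrials=correct)
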
@@ -879,16 +878,17 @@ class SpikeSortingLoader:
         SpikeSortingLoader(eid=eid, pname='probe00', one=one)
     - From a local session and probe name:
         SpikeSortingLoader(session_path=session_path, pname='probe00')
+    NB: When no ONE instance is passed, any datasets that are loaded will not be recorded.
     """
-    one: ONE = None
+    one: One = None
     atlas: None = None
     pid: str = None
     eid: str = ''
     pname: str = ''
-    # the following properties are the outcome of the post init funciton
     session_path: Path = ''
+    # the following properties are the outcome of the post init function
     collections: list = None
-    datasets: list = None  # list of all datasets belonging to the sesion
+    datasets: list = None  # list of all datasets belonging to the session
     # the following properties are the outcome of a reading function
     files: dict = None
     collection: str = ''
@@ -905,11 +905,14 @@ def __post_init__(self):
             self.session_path = self.one.eid2path(self.eid)
         # fully local providing a session path
         else:
-            self.one = One(cache_dir=self.session_path.parents[2], mode='local')
-            df_sessions = cache._make_sessions_df(self.session_path)
-            self.one._cache['sessions'] = df_sessions.set_index('id')
-            self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False)
-            self.eid = str(self.session_path.relative_to(self.session_path.parents[2]))
+            if self.one:
+                self.eid = self.one.to_eid(self.session_path)
+            else:
+                self.one = One(cache_dir=self.session_path.parents[2], mode='local')
+                df_sessions = cache._make_sessions_df(self.session_path)
+                self.one._cache['sessions'] = df_sessions.set_index('id')
+                self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False)
+                self.eid = str(self.session_path.relative_to(self.session_path.parents[2]))
         # populates default properties
         self.collections = self.one.list_collections(
             self.eid, filename='spikes*', collection=f"alf/{self.pname}*")
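
The new branch means a session_path plus an existing ONE instance no longer forces a throwaway local cache; only the fully offline case builds one by indexing the session folder. Sketch (paths are placeholders):

    from pathlib import Path
    session_path = Path('/data/Subjects/KS005/2019-08-29/001')
    # with a ONE instance: the eid is resolved through one.to_eid and loads are recorded
    ssl = SpikeSortingLoader(session_path=session_path, pname='probe00', one=one)
    # fully offline: a local-mode One is built from disk, eid is the relative session path
    ssl_offline = SpikeSortingLoader(session_path=session_path, pname='probe00')
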
@@ -930,7 +933,7 @@ def _get_attributes(dataset_types):
         cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
         return spike_attributes, cluster_attributes
 
-    def _get_spike_sorting_collection(self, spike_sorter='pykilosort', revision=None):
+    def _get_spike_sorting_collection(self, spike_sorter='pykilosort'):
         """
         Filters a list or array of collections to get the relevant spike sorting dataset
         if there is a pykilosort, load it
@@ -982,7 +985,7 @@ def load_spike_sorting(self, **kwargs):
         - alf: the final version of channel locations, same as resolved with the difference that data is on file
         - resolved: channel locations alignments have been agreed upon
         - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate
-        - traced: the histology track has been recovered from microscopy, however the depths may not match, inacurate data
+        - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data
 
         :param spike_sorter: (defaults to 'pykilosort')
         :param dataset_types: list of extra dataset types
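
After loading, the provenance tier listed above can be inspected on the loader; a hedged sketch (the histology attribute comes from the full class definition, which this hunk does not show):

    spikes, clusters, channels = ssl.load_spike_sorting()
    ssl.collection  # e.g. 'alf/probe00/pykilosort'
    ssl.histology   # 'alf' | 'resolved' | 'aligned' | 'traced' (assumed attribute)
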
@@ -1034,4 +1037,5 @@ def merge_clusters(spikes, clusters, channels, cache_dir=None):
     @property
     def url(self):
         """Gets flatiron URL for the session"""
-        return str(self.session_path).replace(str(self.one.alyx.cache_dir), 'https://ibl.flatironinstitute.org')
+        webclient = getattr(self.one, '_web_client', None)
+        return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None
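
The rewritten property derives the URL from the ONE web client rather than string-replacing the cache directory, so the fully local loader introduced above returns None instead of a bogus path. Behaviour sketch (the URL shown is illustrative):

    ssl = SpikeSortingLoader(eid='paste-eid-here', pname='probe00', one=one)
    ssl.url  # e.g. 'https://ibl.flatironinstitute.org/lab/Subjects/KS005/2019-08-29/001'
    SpikeSortingLoader(session_path=session_path, pname='probe00').url  # None: no web client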