1212import matplotlib .pyplot as plt
1313
1414from one .api import ONE , One
15- from one .alf .files import get_alf_path
16- from one .alf .exceptions import ALFObjectNotFound
15+ from one .alf .files import get_alf_path , full_path_parts
16+ from one .alf .exceptions import ALFObjectNotFound , ALFMultipleCollectionsFound
1717from one .alf import cache
1818import one .alf .io as alfio
1919from neuropixel import TIP_SIZE_UM , trace_header
@@ -1324,18 +1324,48 @@ def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=Tr
13241324 _logger .warning (f"Could not load { row ['name' ]} data." )
13251325 _logger .debug (e )
13261326
1327- def load_trials (self ):
1327+ def _find_behaviour_collection (self , obj ):
1328+ """
1329+ Function to find the trial or wheel collection
1330+
1331+ Parameters
1332+ ----------
1333+ obj: str
1334+ Alf object to load, either 'trials' or 'wheel'
1335+ """
1336+ dataset = '_ibl_trials.table.pqt' if obj == 'trials' else '_ibl_wheel.position.npy'
1337+ dsets = self .one .list_datasets (self .eid , dataset )
1338+ if len (dsets ) == 0 :
1339+ return 'alf'
1340+ else :
1341+ collections = [full_path_parts (self .session_path .joinpath (d ), as_dict = True )['collection' ] for d in dsets ]
1342+ if len (set (collections )) == 1 :
1343+ return collections [0 ]
1344+ else :
1345+ _logger .error (f'Multiple collections found { collections } . Specify collection when loading, '
1346+ f'e.g sl.load_{ obj } (collection="{ collections [0 ]} ")' )
1347+ raise ALFMultipleCollectionsFound
1348+
1349+ def load_trials (self , collection = None ):
13281350 """
13291351 Function to load trials data into SessionLoader.trials
1352+
1353+ Parameters
1354+ ----------
1355+ collection: str
1356+ Alf collection of trials data
13301357 """
1358+
1359+ if not collection :
1360+ collection = self ._find_behaviour_collection ('trials' )
13311361 # itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex
13321362 self .one .wildcards = False
13331363 self .trials = self .one .load_object (
1334- self .eid , 'trials' , collection = 'alf' , attribute = r'(?!itiDuration).*' ).to_df ()
1364+ self .eid , 'trials' , collection = collection , attribute = r'(?!itiDuration).*' ).to_df ()
13351365 self .one .wildcards = True
13361366 self .data_info .loc [self .data_info ['name' ] == 'trials' , 'is_loaded' ] = True
13371367
1338- def load_wheel (self , fs = 1000 , corner_frequency = 20 , order = 8 ):
1368+ def load_wheel (self , fs = 1000 , corner_frequency = 20 , order = 8 , collection = None ):
13391369 """
13401370 Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
13411371 is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
@@ -1349,8 +1379,12 @@ def load_wheel(self, fs=1000, corner_frequency=20, order=8):
13491379 Corner frequency of Butterworth low-pass filter, default is 20
13501380 order: int, float
13511381 Order of Butterworth low_pass filter, default is 8
1382+ collection: str
1383+ Alf collection of wheel data
13521384 """
1353- wheel_raw = self .one .load_object (self .eid , 'wheel' )
1385+ if not collection :
1386+ collection = self ._find_behaviour_collection ('wheel' )
1387+ wheel_raw = self .one .load_object (self .eid , 'wheel' , collection = collection )
13541388 if wheel_raw ['position' ].shape [0 ] != wheel_raw ['timestamps' ].shape [0 ]:
13551389 raise ValueError ("Length mismatch between 'wheel.position' and 'wheel.timestamps" )
13561390 # resample the wheel position and compute velocity, acceleration
0 commit comments