Skip to content

Commit 7bd99de

Browse files
committed
Merge remote-tracking branch 'origin/develop' into docs
2 parents 8c8b3d7 + e06aca2 commit 7bd99de

File tree

1 file changed

+40
-6
lines changed

1 file changed

+40
-6
lines changed

brainbox/io/one.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import matplotlib.pyplot as plt
1313

1414
from one.api import ONE, One
15-
from one.alf.files import get_alf_path
16-
from one.alf.exceptions import ALFObjectNotFound
15+
from one.alf.files import get_alf_path, full_path_parts
16+
from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound
1717
from one.alf import cache
1818
import one.alf.io as alfio
1919
from neuropixel import TIP_SIZE_UM, trace_header
@@ -1324,18 +1324,48 @@ def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=Tr
13241324
_logger.warning(f"Could not load {row['name']} data.")
13251325
_logger.debug(e)
13261326

1327-
def load_trials(self):
1327+
def _find_behaviour_collection(self, obj):
1328+
"""
1329+
Function to find the trial or wheel collection
1330+
1331+
Parameters
1332+
----------
1333+
obj: str
1334+
Alf object to load, either 'trials' or 'wheel'
1335+
"""
1336+
dataset = '_ibl_trials.table.pqt' if obj == 'trials' else '_ibl_wheel.position.npy'
1337+
dsets = self.one.list_datasets(self.eid, dataset)
1338+
if len(dsets) == 0:
1339+
return 'alf'
1340+
else:
1341+
collections = [full_path_parts(self.session_path.joinpath(d), as_dict=True)['collection'] for d in dsets]
1342+
if len(set(collections)) == 1:
1343+
return collections[0]
1344+
else:
1345+
_logger.error(f'Multiple collections found {collections}. Specify collection when loading, '
1346+
f'e.g sl.load_{obj}(collection="{collections[0]}")')
1347+
raise ALFMultipleCollectionsFound
1348+
1349+
def load_trials(self, collection=None):
13281350
"""
13291351
Function to load trials data into SessionLoader.trials
1352+
1353+
Parameters
1354+
----------
1355+
collection: str
1356+
Alf collection of trials data
13301357
"""
1358+
1359+
if not collection:
1360+
collection = self._find_behaviour_collection('trials')
13311361
# itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex
13321362
self.one.wildcards = False
13331363
self.trials = self.one.load_object(
1334-
self.eid, 'trials', collection='alf', attribute=r'(?!itiDuration).*').to_df()
1364+
self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*').to_df()
13351365
self.one.wildcards = True
13361366
self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True
13371367

1338-
def load_wheel(self, fs=1000, corner_frequency=20, order=8):
1368+
def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None):
13391369
"""
13401370
Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
13411371
is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
@@ -1349,8 +1379,12 @@ def load_wheel(self, fs=1000, corner_frequency=20, order=8):
13491379
Corner frequency of Butterworth low-pass filter, default is 20
13501380
order: int, float
13511381
Order of Butterworth low_pass filter, default is 8
1382+
collection: str
1383+
Alf collection of wheel data
13521384
"""
1353-
wheel_raw = self.one.load_object(self.eid, 'wheel')
1385+
if not collection:
1386+
collection = self._find_behaviour_collection('wheel')
1387+
wheel_raw = self.one.load_object(self.eid, 'wheel', collection=collection)
13541388
if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]:
13551389
raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps")
13561390
# resample the wheel position and compute velocity, acceleration

0 commit comments

Comments
 (0)