Skip to content

Commit 2845311

Browse files
committed
first prototype session loader
1 parent affbaf8 commit 2845311

File tree

1 file changed

+87
-1
lines changed

1 file changed

+87
-1
lines changed

brainbox/io/one.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment."""
2-
from dataclasses import dataclass
2+
from dataclasses import dataclass, field
33
import logging
44
import os
55
from pathlib import Path
@@ -24,6 +24,7 @@
2424
from brainbox.core import TimeSeries
2525
from brainbox.processing import sync
2626
from brainbox.metrics.single_units import quick_unit_metrics
27+
from brainbox.behavior.wheel import interpolate_position, velocity_smoothed
2728

2829
_logger = logging.getLogger('ibllib')
2930

@@ -1056,3 +1057,88 @@ def samples2times(self, values, direction='forward'):
10561057
'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'),
10571058
}
10581059
return self._sync[direction](values)
1060+
1061+
1062+
@dataclass
1063+
class SessionLoader:
1064+
one: One = None
1065+
eid: str = ''
1066+
session_path: Path = ''
1067+
trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1068+
wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1069+
1070+
def __post_init__(self):
1071+
# Providing no session path, eid and one are required
1072+
if self.session_path is None or self.session_path == '':
1073+
if self.one and self.eid != '' and self.eid is not None:
1074+
self.session_path = self.one.eid2path(self.eid)
1075+
else:
1076+
raise ValueError("If no session path is given, one and eid are required.")
1077+
# Providing a session path
1078+
else:
1079+
if self.one:
1080+
self.eid = self.one.to_eid(self.session_path)
1081+
else:
1082+
# If no one is given, create from cache fully local
1083+
self.session_path = Path(self.session_path)
1084+
self.one = One(cache_dir=self.session_path.parents[2], mode='local')
1085+
df_sessions = cache._make_sessions_df(self.session_path)
1086+
self.one._cache['sessions'] = df_sessions.set_index('id')
1087+
self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False)
1088+
self.eid = str(self.session_path.relative_to(self.session_path.parents[2]))
1089+
1090+
def load_session_data(self, wheel=True):
1091+
# TODO: Dont reload when data already loaded?
1092+
if wheel:
1093+
try:
1094+
self.load_wheel()
1095+
except BaseException as e:
1096+
_logger.warning("Could not load wheel data.")
1097+
_logger.debug(e)
1098+
1099+
def load_trials(self, align_event=None, pre_event=0.5, post_event=0.5):
1100+
self.trials = self.one.load_object(self.eid, 'trials').to_df()
1101+
1102+
def align_trials_to_event(self, align_event='stimOn_times', pre_event=0.5, post_event=0.5):
1103+
possible_events = ['stimOn_times', 'goCue_times', 'goCueTrigger_times',
1104+
'response_times', 'feedback_times', 'firstMovement_times']
1105+
if align_event not in possible_events:
1106+
raise ValueError(f"Argument align_event must be on of {possible_events}")
1107+
1108+
if self.trials.shape == (0, 0):
1109+
self.load_trials()
1110+
1111+
align_str = f"align_{align_event.split('_')[0]}"
1112+
self.trials[f'{align_str}_start'] = self.trials[align_event] - pre_event
1113+
self.trials[f'{align_str}_end'] = self.trials[align_event] + post_event
1114+
diffs = self.trials[f'{align_str}_end'] - np.roll(self.trials[f'{align_str}_start'], -1)
1115+
if np.any(diffs[:-1] > 0):
1116+
_logger.warning(f'{sum(diffs[:-1] > 0)} trials overlapping, try reducing pre_event, post_event or both!')
1117+
1118+
def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
1119+
wheel_pos_raw = self.one.load_dataset(self.eid, '_ibl_wheel.position.npy')
1120+
wheel_times_raw = self.one.load_dataset(self.eid, '_ibl_wheel.timestamps.npy')
1121+
if wheel_times_raw.shape[0] != wheel_pos_raw.shape[0]:
1122+
raise ValueError("Length mismatch between '_ibl_wheel.position.npy' and '_ibl_wheel.timestamps.npy")
1123+
# resample the wheel position and compute velocity, acceleration
1124+
self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration'])
1125+
self.wheel['position'], self.wheel['times'] = interpolate_position(
1126+
wheel_times_raw, wheel_pos_raw, freq=sampling_rate)
1127+
self.wheel['velocity'], self.wheel['acceleration'] = velocity_smoothed(
1128+
self.wheel['position'], freq=sampling_rate, smooth_size=smooth_size)
1129+
1130+
def load_pose(self):
1131+
pass
1132+
1133+
def load_pose_speed(self):
1134+
pass
1135+
1136+
def load_licks(self):
1137+
pass
1138+
1139+
def load_sniffs(self):
1140+
pass
1141+
1142+
def load_pupil_diameter(self):
1143+
pass
1144+

0 commit comments

Comments
 (0)