|
1 | 1 | """Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment.""" |
2 | | -from dataclasses import dataclass |
| 2 | +from dataclasses import dataclass, field |
3 | 3 | import logging |
4 | 4 | import os |
5 | 5 | from pathlib import Path |
|
24 | 24 | from brainbox.core import TimeSeries |
25 | 25 | from brainbox.processing import sync |
26 | 26 | from brainbox.metrics.single_units import quick_unit_metrics |
| 27 | +from brainbox.behavior.wheel import interpolate_position, velocity_smoothed |
27 | 28 |
|
28 | 29 | _logger = logging.getLogger('ibllib') |
29 | 30 |
|
@@ -1056,3 +1057,88 @@ def samples2times(self, values, direction='forward'): |
1056 | 1057 | 'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'), |
1057 | 1058 | } |
1058 | 1059 | return self._sync[direction](values) |
| 1060 | + |
| 1061 | + |
| 1062 | +@dataclass |
| 1063 | +class SessionLoader: |
| 1064 | + one: One = None |
| 1065 | + eid: str = '' |
| 1066 | + session_path: Path = '' |
| 1067 | + trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) |
| 1068 | + wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) |
| 1069 | + |
| 1070 | + def __post_init__(self): |
| 1071 | + # Providing no session path, eid and one are required |
| 1072 | + if self.session_path is None or self.session_path == '': |
| 1073 | + if self.one and self.eid != '' and self.eid is not None: |
| 1074 | + self.session_path = self.one.eid2path(self.eid) |
| 1075 | + else: |
| 1076 | + raise ValueError("If no session path is given, one and eid are required.") |
| 1077 | + # Providing a session path |
| 1078 | + else: |
| 1079 | + if self.one: |
| 1080 | + self.eid = self.one.to_eid(self.session_path) |
| 1081 | + else: |
| 1082 | + # If no one is given, create from cache fully local |
| 1083 | + self.session_path = Path(self.session_path) |
| 1084 | + self.one = One(cache_dir=self.session_path.parents[2], mode='local') |
| 1085 | + df_sessions = cache._make_sessions_df(self.session_path) |
| 1086 | + self.one._cache['sessions'] = df_sessions.set_index('id') |
| 1087 | + self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False) |
| 1088 | + self.eid = str(self.session_path.relative_to(self.session_path.parents[2])) |
| 1089 | + |
| 1090 | + def load_session_data(self, wheel=True): |
| 1091 | + # TODO: Dont reload when data already loaded? |
| 1092 | + if wheel: |
| 1093 | + try: |
| 1094 | + self.load_wheel() |
| 1095 | + except BaseException as e: |
| 1096 | + _logger.warning("Could not load wheel data.") |
| 1097 | + _logger.debug(e) |
| 1098 | + |
| 1099 | + def load_trials(self, align_event=None, pre_event=0.5, post_event=0.5): |
| 1100 | + self.trials = self.one.load_object(self.eid, 'trials').to_df() |
| 1101 | + |
| 1102 | + def align_trials_to_event(self, align_event='stimOn_times', pre_event=0.5, post_event=0.5): |
| 1103 | + possible_events = ['stimOn_times', 'goCue_times', 'goCueTrigger_times', |
| 1104 | + 'response_times', 'feedback_times', 'firstMovement_times'] |
| 1105 | + if align_event not in possible_events: |
| 1106 | + raise ValueError(f"Argument align_event must be on of {possible_events}") |
| 1107 | + |
| 1108 | + if self.trials.shape == (0, 0): |
| 1109 | + self.load_trials() |
| 1110 | + |
| 1111 | + align_str = f"align_{align_event.split('_')[0]}" |
| 1112 | + self.trials[f'{align_str}_start'] = self.trials[align_event] - pre_event |
| 1113 | + self.trials[f'{align_str}_end'] = self.trials[align_event] + post_event |
| 1114 | + diffs = self.trials[f'{align_str}_end'] - np.roll(self.trials[f'{align_str}_start'], -1) |
| 1115 | + if np.any(diffs[:-1] > 0): |
| 1116 | + _logger.warning(f'{sum(diffs[:-1] > 0)} trials overlapping, try reducing pre_event, post_event or both!') |
| 1117 | + |
| 1118 | + def load_wheel(self, sampling_rate=1000, smooth_size=0.03): |
| 1119 | + wheel_pos_raw = self.one.load_dataset(self.eid, '_ibl_wheel.position.npy') |
| 1120 | + wheel_times_raw = self.one.load_dataset(self.eid, '_ibl_wheel.timestamps.npy') |
| 1121 | + if wheel_times_raw.shape[0] != wheel_pos_raw.shape[0]: |
| 1122 | + raise ValueError("Length mismatch between '_ibl_wheel.position.npy' and '_ibl_wheel.timestamps.npy") |
| 1123 | + # resample the wheel position and compute velocity, acceleration |
| 1124 | + self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) |
| 1125 | + self.wheel['position'], self.wheel['times'] = interpolate_position( |
| 1126 | + wheel_times_raw, wheel_pos_raw, freq=sampling_rate) |
| 1127 | + self.wheel['velocity'], self.wheel['acceleration'] = velocity_smoothed( |
| 1128 | + self.wheel['position'], freq=sampling_rate, smooth_size=smooth_size) |
| 1129 | + |
| 1130 | + def load_pose(self): |
| 1131 | + pass |
| 1132 | + |
| 1133 | + def load_pose_speed(self): |
| 1134 | + pass |
| 1135 | + |
| 1136 | + def load_licks(self): |
| 1137 | + pass |
| 1138 | + |
| 1139 | + def load_sniffs(self): |
| 1140 | + pass |
| 1141 | + |
| 1142 | + def load_pupil_diameter(self): |
| 1143 | + pass |
| 1144 | + |
0 commit comments