Skip to content

Commit 9e0cdbc

Browse files
committed
pose and motion energy loading
1 parent e2fa36d commit 9e0cdbc

File tree

1 file changed

+40
-21
lines changed

1 file changed

+40
-21
lines changed

brainbox/io/one.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,7 @@ class SessionLoader:
10681068
trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
10691069
wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
10701070
poses: dict = field(default_factory=dict, repr=False)
1071+
motion_energy: dict = field(default_factory=dict, repr=False)
10711072

10721073
def __post_init__(self):
10731074
# Providing no session path, eid and one are required
@@ -1089,11 +1090,11 @@ def __post_init__(self):
10891090
self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False)
10901091
self.eid = str(self.session_path.relative_to(self.session_path.parents[2]))
10911092

1092-
def load_session_data(self, trials=True, wheel=True, poses=True):
1093+
def load_session_data(self, trials=True, wheel=True, poses=True, motion_energy=True):
10931094
# TODO: Dont reload when data already loaded?
1094-
names = ['trials', 'wheel', 'poses']
1095-
args = [trials, wheel, poses]
1096-
loading_funcs = [self.load_trials, self.load_wheel, self.load_poses]
1095+
names = ['trials', 'wheel', 'poses', 'motion_energy']
1096+
args = [trials, wheel, poses, motion_energy]
1097+
loading_funcs = [self.load_trials, self.load_wheel, self.load_poses, self.load_motion_energy]
10971098

10981099
for name, arg, loading_func in zip(names, args, loading_funcs):
10991100
if arg is True:
@@ -1121,31 +1122,30 @@ def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
11211122
def load_poses(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
11221123
for view in views:
11231124
try:
1124-
dlc_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'])
1125-
# Sometimes the camera times exist but are empty
1126-
if dlc_raw['times'].shape[0] == 0:
1127-
_logger.error(f'Camera times empty for {view}Camera. Skipping camera.')
1128-
# For pre-GPIO sessions, it is possible that the camera times are longer than the actual video.
1129-
# This is because the first few frames are sometimes not recorded. We can remove the first few
1130-
# timestamps in this case
1131-
elif dlc_raw['times'].shape[0] > dlc_raw['dlc'].shape[0]:
1132-
dlc_raw['times'][-dlc_raw['dlc'].shape[0]:]
1133-
elif dlc_raw['times'].shape[0] < dlc_raw['dlc'].shape[0]:
1134-
_logger.error(f'Camera times are shorter than pose estimation for {view}Camera. Skipping camera.')
1135-
else:
1136-
self.poses[view] = likelihood_threshold(dlc_raw['dlc'], likelihood_thr)
1137-
self.poses[view].insert(0, 'times', dlc_raw['times'])
1125+
pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'])
1126+
# Double check if video timestamps are correct length or can be fixed
1127+
times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc'])
1128+
self.poses[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr)
1129+
self.poses[f'{view}Camera'].insert(0, 'times', times_fixed)
11381130
except BaseException as e:
11391131
_logger.error(f'Could not load pose data for {view}Camera. Skipping camera.')
11401132
_logger.debug(e)
11411133

11421134
def load_motion_energy(self, views=['left', 'right', 'body']):
1143-
names = {'left': 'whisker_left',
1144-
'right': 'whisker_right',
1145-
'body': 'body'}
1135+
names = {'left': 'whiskerMotionEnergy',
1136+
'right': 'whiskerMotionEnergy',
1137+
'body': 'bodyMotionEnergy'}
11461138
for view in views:
11471139
try:
11481140
me_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'])
1141+
# Double check if video timestamps are correct length or can be fixed
1142+
times_fixed, motion_energy = self._check_video_timestamps(
1143+
view, me_raw['times'], me_raw['ROIMotionEnergy'])
1144+
self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy)
1145+
self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed)
1146+
except BaseException as e:
1147+
_logger.error(f'Could not load motion energy data for {view}Camera. Skipping camera.')
1148+
_logger.debug(e)
11491149

11501150
def load_licks(self):
11511151
pass
@@ -1169,3 +1169,22 @@ def align_trials_to_event(self, align_event='stimOn_times', pre_event=0.5, post_
11691169
diffs = self.trials[f'{align_str}_end'] - np.roll(self.trials[f'{align_str}_start'], -1)
11701170
if np.any(diffs[:-1] > 0):
11711171
_logger.warning(f'{sum(diffs[:-1] > 0)} trials overlapping, try reducing pre_event, post_event or both!')
1172+
1173+
def _check_video_timestamps(self, view, video_timestamps, video_data):
1174+
pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'])
1175+
# If camera times are shorter than video data, or empty, no current fix
1176+
if video_timestamps.shape[0] < video_data.shape[0]:
1177+
if video_timestamps.shape[0] == 0:
1178+
msg = f'Camera times empty for {view}Camera.'
1179+
else:
1180+
msg = f'Camera times are shorter than video data for {view}Camera.'
1181+
_logger.warning(msg)
1182+
raise ValueError(msg)
1183+
# For pre-GPIO sessions, it is possible that the camera times are longer than the actual video.
1184+
# This is because the first few frames are sometimes not recorded. We can remove the first few
1185+
# timestamps in this case
1186+
elif video_timestamps.shape[0] > video_data.shape[0]:
1187+
video_timestamps_fixed = video_timestamps[-video_data.shape[0]:]
1188+
return video_timestamps_fixed, video_data
1189+
else:
1190+
return video_timestamps, video_data

0 commit comments

Comments
 (0)