Skip to content

Commit ff64d41

Browse files
committed
Begin docstrings
1 parent 89772d1 commit ff64d41

File tree

1 file changed

+111
-19
lines changed

1 file changed

+111
-19
lines changed

brainbox/io/one.py

Lines changed: 111 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,9 +1062,64 @@ def samples2times(self, values, direction='forward'):
10621062

10631063
@dataclass
10641064
class SessionLoader:
1065+
"""
1066+
Object to load session data for a give session in the recommended way.
1067+
1068+
Parameters
1069+
----------
1070+
one: one.api.ONE instance
1071+
Can be in remote or local mode (required)
1072+
session_path: string or pathlib.Path
1073+
The absolute path to the session (one of session_path or eid is required)
1074+
eid: string
1075+
database UUID of the session (one of session_path or eid is required)
1076+
1077+
If both are provided, session_path takes precedence over eid.
1078+
1079+
Examples
1080+
--------
1081+
1) Load all available session data for one session:
1082+
>>> from one.api import ONE
1083+
>>> from brainbox.io.one import SessionLoader
1084+
>>> one = ONE()
1085+
>>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/')
1086+
# Object is initiated, but no data is loaded as you can see in the data_info attribute
1087+
>>> sess_loader.data_info
1088+
name is_loaded
1089+
0 trials False
1090+
1 wheel False
1091+
2 poses False
1092+
3 motion_energy False
1093+
4 pupil False
1094+
1095+
# Loading all available session data, the data_info attribute now shows which data has been loaded
1096+
>>> sess_loader.load_session_data()
1097+
>>> sess_loader.data_info
1098+
name is_loaded
1099+
0 trials True
1100+
1 wheel True
1101+
2 poses True
1102+
3 motion_energy True
1103+
4 pupil False
1104+
1105+
# You can access the data via the respective attributes, e.g.
1106+
>>> sess_loader.trials.shape
1107+
(626, 18)
1108+
# Each data comes with its own timestamps in a column called 'times'
1109+
>>> sess_loader.pose['bodyCamera']['times']
1110+
0 6.201239
1111+
1 6.234569
1112+
2 6.267899
1113+
3 6.301229
1114+
4 6.334592
1115+
...
1116+
# In order to control the loading of specific data by e.g. specifying parameters, use the individual loading
1117+
functions:
1118+
>>> sess_loader.load_wheel(sampling_rate=100)
1119+
"""
10651120
one: One = None
1066-
eid: str = ''
10671121
session_path: Path = ''
1122+
eid: str = ''
10681123
data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
10691124
trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
10701125
wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
@@ -1073,12 +1128,17 @@ class SessionLoader:
10731128
pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
10741129

10751130
def __post_init__(self):
1131+
"""
1132+
Function that runs automatically after initiation of the dataclass attributes.
1133+
Checks for required inputs, sets session_path and eid, creates data_info table.
1134+
"""
10761135
if self.one is None:
10771136
raise ValueError("An input to one is required. If not connection to a database is desired, it can be "
10781137
"a fully local instance of One.")
10791138
# If session path is given, takes precedence over eid
10801139
if self.session_path is not None and self.session_path != '':
10811140
self.eid = self.one.to_eid(self.session_path)
1141+
self.session_path = Path(self.session_path)
10821142
# Providing no session path, try to infer from eid
10831143
else:
10841144
if self.eid is not None and self.eid != '':
@@ -1096,7 +1156,30 @@ def __post_init__(self):
10961156
self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False]*len(data_names)))
10971157

10981158
def load_session_data(self, trials=True, wheel=True, poses=True, motion_energy=True, pupil=True, reload=False):
1099-
1159+
"""
1160+
Function to load available session data into the SessionLoader object. Input parameters allow to control which
1161+
data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input
1162+
parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored
1163+
in SessionLoader.data_info
1164+
1165+
Parameters
1166+
----------
1167+
trials: boolean
1168+
Whether to load all trials data into SessionLoader.trials, default is True
1169+
wheel: boolean
1170+
Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True
1171+
poses: boolean
1172+
Whether to load pose tracking results (DLC) for each available camera into SessionLoader.poses,
1173+
default is True
1174+
motion_energy: boolean
1175+
Whether to load motion energy data (whisker pad for left/right camera, body for body camera)
1176+
into SessionLoader.motion_energy, default is True
1177+
pupil: boolean
1178+
Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil,
1179+
default is True
1180+
reload: boolean
1181+
Whether to reload data that has already been loaded into this SessionLoader object, default is False
1182+
"""
11001183
load_df = self.data_info.copy()
11011184
load_df['to_load'] = [
11021185
trials,
@@ -1128,10 +1211,25 @@ def load_session_data(self, trials=True, wheel=True, poses=True, motion_energy=T
11281211
_logger.debug(e)
11291212

11301213
def load_trials(self):
1214+
"""
1215+
Function to load trials data into SessionLoader.trials
1216+
"""
11311217
self.trials = self.one.load_object(self.eid, 'trials').to_df()
11321218
self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True
11331219

11341220
def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
1221+
"""
1222+
Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
1223+
is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
1224+
smoothing is applied.
1225+
1226+
Parameters
1227+
----------
1228+
sampling_rate: float
1229+
Rate at which to sample the wheel position
1230+
smooth_size: float
1231+
Kernel for smoothing the wheel data to compute velocity and acceleration
1232+
"""
11351233
wheel_raw = self.one.load_object(self.eid, 'wheel')
11361234
# TODO: Fix this instead of raising error?
11371235
if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]:
@@ -1145,6 +1243,17 @@ def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
11451243
self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True
11461244

11471245
def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
1246+
"""
1247+
Function to load the pose estimation results (DLC) into SessionLoader.poses
1248+
Parameters
1249+
----------
1250+
likelihood_thr
1251+
views
1252+
1253+
Returns
1254+
-------
1255+
1256+
"""
11481257
for view in views:
11491258
try:
11501259
pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'])
@@ -1215,23 +1324,6 @@ def load_pupil(self, snr_thresh=5):
12151324
self.pupil = pd.DataFrame
12161325
raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.')
12171326

1218-
def align_trials_to_event(self, align_event='stimOn_times', pre_event=0.5, post_event=0.5):
1219-
possible_events = ['stimOn_times', 'goCue_times', 'goCueTrigger_times',
1220-
'response_times', 'feedback_times', 'firstMovement_times']
1221-
if align_event not in possible_events:
1222-
raise ValueError(f"Argument align_event must be on of {possible_events}")
1223-
1224-
if self.trials.shape == (0, 0):
1225-
_logger.info("No trials data loaded. Trying to load trials data.")
1226-
self.load_trials()
1227-
1228-
align_str = f"align_{align_event.split('_')[0]}"
1229-
self.trials[f'{align_str}_start'] = self.trials[align_event] - pre_event
1230-
self.trials[f'{align_str}_end'] = self.trials[align_event] + post_event
1231-
diffs = self.trials[f'{align_str}_end'] - np.roll(self.trials[f'{align_str}_start'], -1)
1232-
if np.any(diffs[:-1] > 0):
1233-
_logger.warning(f'{sum(diffs[:-1] > 0)} trials overlapping, try reducing pre_event, post_event or both!')
1234-
12351327
def _check_video_timestamps(self, view, video_timestamps, video_data):
12361328
# If camera times are shorter than video data, or empty, no current fix
12371329
if video_timestamps.shape[0] < video_data.shape[0]:

0 commit comments

Comments
 (0)