2525from brainbox .processing import sync
2626from brainbox .metrics .single_units import quick_unit_metrics
2727from brainbox .behavior .wheel import interpolate_position , velocity_smoothed
28- from brainbox .behavior .dlc import likelihood_threshold
28+ from brainbox .behavior .dlc import likelihood_threshold , get_pupil_diameter , get_smooth_pupil_diameter
2929
3030_logger = logging .getLogger ('ibllib' )
3131
@@ -1065,10 +1065,12 @@ class SessionLoader:
10651065 one : One = None
10661066 eid : str = ''
10671067 session_path : Path = ''
1068+ data_info : pd .DataFrame = field (default_factory = pd .DataFrame , repr = False )
10681069 trials : pd .DataFrame = field (default_factory = pd .DataFrame , repr = False )
10691070 wheel : pd .DataFrame = field (default_factory = pd .DataFrame , repr = False )
1070- poses : dict = field (default_factory = dict , repr = False )
1071+ pose : dict = field (default_factory = dict , repr = False )
10711072 motion_energy : dict = field (default_factory = dict , repr = False )
1073+ pupil : pd .DataFrame = field (default_factory = pd .DataFrame , repr = False )
10721074
10731075 def __post_init__ (self ):
10741076 # Providing no session path, eid and one are required
@@ -1089,23 +1091,51 @@ def __post_init__(self):
10891091 self .one ._cache ['sessions' ] = df_sessions .set_index ('id' )
10901092 self .one ._cache ['datasets' ] = cache ._make_datasets_df (self .session_path , hash_files = False )
10911093 self .eid = str (self .session_path .relative_to (self .session_path .parents [2 ]))
1092-
1093- def load_session_data (self , trials = True , wheel = True , poses = True , motion_energy = True ):
1094- # TODO: Dont reload when data already loaded?
1095- names = ['trials' , 'wheel' , 'poses' , 'motion_energy' ]
1096- args = [trials , wheel , poses , motion_energy ]
1097- loading_funcs = [self .load_trials , self .load_wheel , self .load_poses , self .load_motion_energy ]
1098-
1099- for name , arg , loading_func in zip (names , args , loading_funcs ):
1100- if arg is True :
1094+ # Information of data that is and can be loaded
1095+ data_names = [
1096+ 'trials' ,
1097+ 'wheel' ,
1098+ 'poses' ,
1099+ 'motion_energy' ,
1100+ 'pupil'
1101+ ]
1102+ self .data_info = pd .DataFrame (columns = ['name' , 'is_loaded' ], data = zip (data_names , [False ]* len (data_names )))
1103+
1104+ def load_session_data (self , trials = True , wheel = True , poses = True , motion_energy = True , pupil = True , reload = False ):
1105+
1106+ load_df = self .data_info .copy ()
1107+ load_df ['to_load' ] = [
1108+ trials ,
1109+ wheel ,
1110+ poses ,
1111+ motion_energy ,
1112+ pupil
1113+ ]
1114+ load_df ['load_func' ] = [
1115+ self .load_trials ,
1116+ self .load_wheel ,
1117+ self .load_pose ,
1118+ self .load_motion_energy ,
1119+ self .load_pupil
1120+ ]
1121+
1122+ for idx , row in load_df .iterrows ():
1123+ if row ['to_load' ] is False :
1124+ _logger .debug (f"Not loading { row ['name' ]} data, set to False." )
1125+ elif row ['is_loaded' ] is True and reload is False :
1126+ _logger .debug (f"Not loading { row ['name' ]} data, is already loaded and reload=False." )
1127+ else :
11011128 try :
1102- loading_func ()
1129+ _logger .info (f"Loading { row ['name' ]} data" )
1130+ row ['load_func' ]()
1131+ self .data_info .loc [idx , 'is_loaded' ] = True
11031132 except BaseException as e :
1104- _logger .warning (f"Could not load { name } data." )
1133+ _logger .warning (f"Could not load { row [ ' name' ] } data." )
11051134 _logger .debug (e )
11061135
11071136 def load_trials (self ):
11081137 self .trials = self .one .load_object (self .eid , 'trials' ).to_df ()
1138+ self .data_info .loc [self .data_info ['name' ] == 'trials' , 'is_loaded' ] = True
11091139
11101140 def load_wheel (self , sampling_rate = 1000 , smooth_size = 0.03 ):
11111141 wheel_raw = self .one .load_object (self .eid , 'wheel' )
@@ -1118,15 +1148,17 @@ def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
11181148 wheel_raw ['timestamps' ], wheel_raw ['position' ], freq = sampling_rate )
11191149 self .wheel ['velocity' ], self .wheel ['acceleration' ] = velocity_smoothed (
11201150 self .wheel ['position' ], freq = sampling_rate , smooth_size = smooth_size )
1151+ self .data_info .loc [self .data_info ['name' ] == 'wheel' , 'is_loaded' ] = True
11211152
1122- def load_poses (self , likelihood_thr = 0.9 , views = ['left' , 'right' , 'body' ]):
1153+ def load_pose (self , likelihood_thr = 0.9 , views = ['left' , 'right' , 'body' ]):
11231154 for view in views :
11241155 try :
11251156 pose_raw = self .one .load_object (self .eid , f'{ view } Camera' , attribute = ['dlc' , 'times' ])
11261157 # Double check if video timestamps are correct length or can be fixed
11271158 times_fixed , dlc = self ._check_video_timestamps (view , pose_raw ['times' ], pose_raw ['dlc' ])
1128- self .poses [f'{ view } Camera' ] = likelihood_threshold (dlc , likelihood_thr )
1129- self .poses [f'{ view } Camera' ].insert (0 , 'times' , times_fixed )
1159+ self .pose [f'{ view } Camera' ] = likelihood_threshold (dlc , likelihood_thr )
1160+ self .pose [f'{ view } Camera' ].insert (0 , 'times' , times_fixed )
1161+ self .data_info .loc [self .data_info ['name' ] == 'pose' , 'is_loaded' ] = True
11301162 except BaseException as e :
11311163 _logger .error (f'Could not load pose data for { view } Camera. Skipping camera.' )
11321164 _logger .debug (e )
@@ -1143,15 +1175,51 @@ def load_motion_energy(self, views=['left', 'right', 'body']):
11431175 view , me_raw ['times' ], me_raw ['ROIMotionEnergy' ])
11441176 self .motion_energy [f'{ view } Camera' ] = pd .DataFrame (columns = [names [view ]], data = motion_energy )
11451177 self .motion_energy [f'{ view } Camera' ].insert (0 , 'times' , times_fixed )
1178+ self .data_info .loc [self .data_info ['name' ] == 'motion_energy' , 'is_loaded' ] = True
11461179 except BaseException as e :
11471180 _logger .error (f'Could not load motion energy data for { view } Camera. Skipping camera.' )
11481181 _logger .debug (e )
11491182
11501183 def load_licks (self ):
11511184 pass
11521185
1153- def load_pupil_diameter (self ):
1154- pass
1186+ def load_pupil (self , snr_thresh = 5 ):
1187+ # Try to load from features
1188+ feat_raw = self .one .load_object (self .eid , 'leftCamera' , attribute = ['times' , 'features' ])
1189+ if 'features' in feat_raw .keys ():
1190+ times_fixed , feats = self ._check_video_timestamps (feat_raw ['times' ], feat_raw ['features' ])
1191+ self .pupil = feats .copy ()
1192+ self .pupil .insert (0 , 'times' , times_fixed )
1193+
1194+ # If unavailable compute on the fly
1195+ else :
1196+ _logger .info ('Pupil diameter not available, trying to compute on the fly.' )
1197+ if self .data_info [self .data_info ['name' ] == 'pose' , 'is_loaded' ] and 'leftCamera' in self .pose .keys ():
1198+ # If pose data is already loaded, we don't know if it was threshold at 0.9, so we need a little stunt
1199+ copy_pose = self .pose ['leftCamera' ].copy () # Save the previously loaded pose data
1200+ self .load_pose (views = ['left' ], likelihood_thr = 0.9 ) # Load new with threshold 0.9
1201+ dlc_thr = self .pose ['leftCamera' ].copy () # Save the threshold pose data in new variable
1202+ self .pose ['leftCamera' ] = copy_pose .copy () # Get previously loaded pose data back in place
1203+ else :
1204+ self .load_pose (views = ['left' ], likelihood_thr = 0.9 )
1205+ dlc_thr = self .pose ['leftCamera' ].copy ()
1206+
1207+ self .pupil ['pupilDiameter_raw' ] = get_pupil_diameter (dlc_thr )
1208+ try :
1209+ self .pupil ['pupilDiameter_smooth' ] = get_smooth_pupil_diameter (self .pupil ['pupilDiameter_raw' ], 'left' )
1210+ except BaseException as e :
1211+ _logger .error (f"Computing smooth pupil diameter failed, saving all NaNs." )
1212+ _logger .debug (e )
1213+ self .pupil ['pupilDiameter_smooth' ] = np .nan
1214+
1215+ if not np .all (np .isnan (self .pupil ['pupilDiameter_smooth' ])):
1216+ good_idxs = np .where (
1217+ ~ np .isnan (self .pupil ['pupilDiameter_smooth' ]) & ~ np .isnan (self .pupil ['pupilDiameter_raw' ]))[0 ]
1218+ snr = (np .var (self .pupil ['pupilDiameter_smooth' ][good_idxs ]) /
1219+ (np .var (self .pupil ['pupilDiameter_smooth' ][good_idxs ] - self .pupil ['pupilDiameter_raw' ][good_idxs ])))
1220+ if snr < snr_thresh :
1221+ self .pupil = pd .DataFrame
1222+ raise ValueError (f'Pupil diameter SNR ({ snr :.2f} ) below threshold SNR ({ snr_thresh } ), removing data.' )
11551223
11561224 def align_trials_to_event (self , align_event = 'stimOn_times' , pre_event = 0.5 , post_event = 0.5 ):
11571225 possible_events = ['stimOn_times' , 'goCue_times' , 'goCueTrigger_times' ,
0 commit comments