@@ -135,6 +135,7 @@ def __init__(self, session_path_or_eid, camera, **kwargs):
135135 self .n_samples = kwargs .pop ('n_samples' , 100 )
136136 self .sync_collection = kwargs .pop ('sync_collection' , None )
137137 self .sync = kwargs .pop ('sync_type' , None )
138+ self .protocol = kwargs .pop ('protocol' , None )
138139 super ().__init__ (session_path_or_eid , ** kwargs )
139140
140141 # Data
@@ -163,7 +164,10 @@ def __init__(self, session_path_or_eid, camera, **kwargs):
163164 self .outcome = spec .QC .NOT_SET
164165
165166 # Specify any checks to remove
166- self .checks_to_remove = []
167+ if self .protocol is not None and 'habituation' in self .protocol :
168+ self .checks_to_remove = ['check_wheel_alignment' ]
169+ else :
170+ self .checks_to_remove = []
167171 self ._type = None
168172
169173 @property
@@ -271,8 +275,12 @@ def load_data(self, extract_times: bool = False, load_video: bool = True) -> Non
271275 else :
272276 raise NotImplementedError (f'Unknown namespace "{ ns } "' )
273277 else :
274- wheel_data = training_wheel .get_wheel_position (
275- self .session_path , task_collection = task_collection )
278+ if self .protocol is not None and 'habituation' in self .protocol :
279+ wheel_data = training_wheel .get_wheel_position (
280+ self .session_path , task_collection = task_collection )
281+ else :
282+ wheel_data = [None , None ]
283+
276284 self .data ['wheel' ] = Bunch (zip (wheel_keys , wheel_data ))
277285
278286 # Find short period of wheel motion for motion correlation.
0 commit comments