11import logging
22
3+ import cv2
34import numpy as np
45import pandas as pd
56import one .alf .io as alfio
89
910from ibllib .pipes .base_tasks import BehaviourTask
1011from ibllib .exceptions import SyncBpodFpgaException
12+ from ibllib .io .video import get_video_meta
1113from ibllib .io .extractors .ephys_fpga import get_protocol_period , get_sync_fronts
1214from ibllib .io .raw_daq_loaders import load_timeline_sync_and_chmap
1315from ibllib .io .extractors .mesoscope import plot_timeline
1416
1517_logger = logging .getLogger ('ibllib' ).getChild (__name__ )
18+ _logger .setLevel (logging .DEBUG )
1619
1720
1821class PassiveVideoTimeline (BehaviourTask ):
@@ -25,6 +28,7 @@ def signature(self):
2528 signature = {}
2629 signature ['input_files' ] = [
2730 ('_sp_taskData.raw.*' , self .collection , True ), # TODO Create dataset type?
31+ ('_sp_video.raw.*' , self .collection , False ),
2832 ('_iblrig_taskSettings.raw.*' , self .collection , True ),
2933 (f'_{ self .sync_namespace } _DAQdata.raw.npy' , self .sync_collection , True ),
3034 (f'_{ self .sync_namespace } _DAQdata.timestamps.npy' , self .sync_collection , True ),
@@ -63,7 +67,27 @@ def generate_sync_sequence(seed=1234, ns=3600, res=8):
6367 finally :
6468 np .random .set_state (state )
6569
66- def extract_frame_times (self , save = True , frame_rate = 60 , display = False , ** kwargs ):
70+ def load_sync_sequence_from_video (self , video_file , location = 'bottom right' , size = (5 , 5 )):
71+ cap = cv2 .VideoCapture (str (video_file ))
72+ sequence = []
73+ location = location .casefold ().split ()
74+ loc_map = {
75+ 'top' : slice (0 , size [1 ]), 'bottom' : slice (- size [1 ], None ),
76+ 'left' : slice (0 , size [0 ]), 'right' : slice (- size [0 ], None )}
77+ idx = tuple (loc_map [x ] for x in reversed (location )) # h, w
78+ success = True
79+ while success :
80+ success , frame = cap .read ()
81+ if success :
82+ # Find the sync square in the video frame
83+ pixel = np .mean (frame [idx ])
84+ sequence .append (int (pixel > 128 ))
85+ length = int (cap .get (cv2 .CAP_PROP_FRAME_COUNT ))
86+ cap .release ()
87+ assert len (sequence ) == length , 'sequence length does not match video length'
88+ return np .array (sequence )
89+
90+ def extract_frame_times (self , save = True , frame_rate = None , display = False , ** kwargs ):
6791 """Extract the Bpod trials data and Timeline acquired signals.
6892
6993 Sync requires three steps:
@@ -76,7 +100,7 @@ def extract_frame_times(self, save=True, frame_rate=60, display=False, **kwargs)
76100 save : bool, optional
77101 Whether to save the video frame times to file, by default True.
78102 frame_rate : int, optional
79- The frame rate of the video presented, by default 60 .
103+ The frame rate of the video presented, by default 30 .
80104 display : bool, optional
81105 When true, plot the aligned frame times. By default False.
82106
@@ -97,13 +121,32 @@ def extract_frame_times(self, save=True, frame_rate=60, display=False, **kwargs)
97121 SyncBpodFpgaException
98122 The synchronization of frame times was likely unsuccessful.
99123 """
124+ DEFAULT_FRAME_RATE = 30
100125 _ , (p ,), _ = self .input_files [0 ].find_files (self .session_path )
101126 # Load raw data
102127 proc_data = pd .read_parquet (p )
103128 sync_path = self .session_path / self .sync_collection
104129 self .timeline = alfio .load_object (sync_path , 'DAQdata' , namespace = 'timeline' )
105130 sync , chmap = load_timeline_sync_and_chmap (sync_path , timeline = self .timeline )
106131
132+ # Attempt to get the frame rate from the video file if not provided
133+ video_file = next (self .session_path .joinpath (self .collection ).glob ('_sp_video.raw.*' ))
134+ if video_file .exists ():
135+ video_meta = get_video_meta (video_file )
136+ if frame_rate is not None and frame_rate != video_meta .fps :
137+ _logger .warning (
138+ 'Frame rate mismatch: %.2f Hz (video) vs %.2f Hz (provided). Using %.2f Hz' ,
139+ video_meta .fps , frame_rate , video_meta .fps )
140+ else :
141+ _logger .debug ('Video frame rate: %.2f Hz' , video_meta .fps )
142+ frame_rate = video_meta .fps
143+ else :
144+ video_meta = None
145+ frame_rate = frame_rate or DEFAULT_FRAME_RATE
146+ _logger .warning ('Video not found. Assumed video frame rate: %.2f Hz' , frame_rate )
147+ Fs = self .timeline ['meta' ]['daqSampleRate' ]
148+ assert Fs > frame_rate * 1.5 , 'DAQ sample rate must be higher than video frame rate'
149+
107150 bpod = get_sync_fronts (sync , chmap ['bpod' ])
108151 # Get the spacer times for this protocol
109152 if any (arg in kwargs for arg in ('tmin' , 'tmax' )):
@@ -133,7 +176,8 @@ def extract_frame_times(self, save=True, frame_rate=60, display=False, **kwargs)
133176 # These durations are longer than video actually played and will be cut down after
134177 durations = (proc_data ['intervals_1' ] - proc_data ['intervals_0' ]).values
135178 max_n_frames = np .max (np .ceil (durations * frame_rate ).astype (int ))
136- frame_times = np .full ((max_n_frames , len (proc_data )), np .nan )
179+ n_frames = video_meta .length if video_meta else max_n_frames
180+ frame_times = np .full ((n_frames , len (proc_data )), np .nan )
137181
138182 sync_sequence = kwargs .get ('sync_sequence' , self .generate_sync_sequence ())
139183 for i , rep in proc_data .iterrows ():
@@ -147,6 +191,9 @@ def extract_frame_times(self, save=True, frame_rate=60, display=False, **kwargs)
147191 end = start + (rep ['intervals_1' ] - rep ['intervals_0' ])
148192 f2ttl = get_sync_fronts (sync , chmap ['frame2ttl' ])
149193 ts = f2ttl ['times' ][np .logical_and (f2ttl ['times' ] >= start , f2ttl ['times' ] < end )]
194+ if video_meta :
195+ _logger .debug ('Repeat %i: video duration: %.2fs, f2ttl duration: %.2f' ,
196+ i , video_meta .duration .seconds , ts [- 1 ] - ts [0 ])
150197
151198 # video_runtime is the video length reported by VLC.
152199 # As it was added later, the less accurate media player timestamps may be used if the former is not available
@@ -162,23 +209,24 @@ def extract_frame_times(self, save=True, frame_rate=60, display=False, **kwargs)
162209 # Find change points (black <-> white indices)
163210 x , = np .where (np .abs (np .diff (x )))
164211 # Include first frame as change point
165- x = np .r_ [0 , x ]
212+ x = np .r_ [0 , x + 1 ]
166213 # Synchronize the two by aligning flip times
167214 DRIFT_THRESHOLD_PPM = 50
168- Fs = self .timeline ['meta' ]['daqSampleRate' ]
169- fcn , drift = ibldsp .utils .sync_timestamps (sequence_times [x ], ts , tbin = 1 / Fs , linear = True )
215+ fcn , drift = ibldsp .utils .sync_timestamps (sequence_times [x ], ts , linear = True )
170216 # Log any major drift or raise if too large
171217 if np .abs (drift ) > DRIFT_THRESHOLD_PPM * 2 and x .size - ts .size > 100 :
172- raise SyncBpodFpgaException (f'sync cluster f*ck: drift = { drift :.2f} , changepoint difference = { x .size - ts .size } ' )
173- elif drift > DRIFT_THRESHOLD_PPM :
174- _logger .warning ('BPOD/FPGA synchronization shows values greater than %.2f ppm' ,
175- DRIFT_THRESHOLD_PPM )
218+ raise SyncBpodFpgaException (
219+ f'sync cluster f*ck: drift = { drift :.2f} , changepoint difference = { x .size - ts .size } ' )
220+ elif np .abs (drift ) > DRIFT_THRESHOLD_PPM :
221+ _logger .warning ('Frame synchronization shows values greater than %.2g ppm' , DRIFT_THRESHOLD_PPM )
222+ _logger .debug ('Frame synchronization drift: %.2f ppm' , drift )
176223
177224 # Get the frame times in timeline time
178225 frame_times [:len (sequence_times ), i ] = fcn (sequence_times )
179226
180227 # Trim down to length of repeat with most frames
181- frame_times = frame_times [:np .where (np .all (np .isnan (frame_times ), axis = 1 ))[0 ][0 ], :]
228+ if np .any (empty := np .all (np .isnan (frame_times ), axis = 1 )):
229+ frame_times = frame_times [:np .where (empty )[0 ][0 ], :]
182230
183231 if display :
184232 import matplotlib .pyplot as plt
@@ -191,16 +239,46 @@ def extract_frame_times(self, save=True, frame_rate=60, display=False, **kwargs)
191239 ax [0 ].title .set_text ('frame2ttl' )
192240 cmap = colormaps ['plasma' ]
193241 for i , times in enumerate (frame_times .T ):
242+ # Plot the sync sequence and sync'd frame times
194243 rgba = cmap (i / frame_times .shape [1 ])
195244 ax [1 ].plot (times , sync_sequence [:len (times )], c = rgba , label = f'{ i } ' )
245+ # Plot the f2ttl values
246+ idx = bpod_rep_starts [i ]
247+ start = bpod ['times' ][idx ]
248+ try :
249+ end = bpod ['times' ][idx + 1 ]
250+ except IndexError :
251+ end = start + (rep ['intervals_1' ] - rep ['intervals_0' ])
252+ mask = np .logical_and (f2ttl ['times' ] >= start , f2ttl ['times' ] < end )
253+ squares (f2ttl ['times' ][mask ], f2ttl ['polarities' ][mask ],
254+ yrange = [0 , 1 ], ax = ax [1 ], linestyle = ':' , color = 'k' )
196255 ax [1 ].title .set_text ('aligned sync square sequence' )
197256 ax [1 ].set_yticks ((0 , 1 ))
198257 ax [1 ].set_yticklabels ([- 1 , 1 ])
199258 plt .legend (markerfirst = False , title = 'repeat #' , loc = 'upper right' , facecolor = 'white' )
259+
260+ # Check the sync sequence from the video
261+ if video_file .exists ():
262+ observed = self .load_sync_sequence_from_video (video_file )
263+ _ , ax = plt .subplots (2 , 1 , sharex = True )
264+ ax [0 ].title .set_text ('generated sync square sequence' )
265+ ax [0 ].plot (sync_sequence [:observed .size ])
266+ ax [1 ].title .set_text ('observed sync square sequence' )
267+ ax [1 ].plot (observed )
268+
269+ # resample the f2ttl sequence to the frame times
270+ # tts = ts-ts[0]
271+ # from scipy import interpolate
272+ # interp = interpolate.interp1d(tts, pol, kind = "nearest")
273+ # ttts = np.arange(0, tts[-1], 1/frame_rate)
274+ # ax[2].plot(interp(ttts))
275+ # squares(ts-ts[0], pol, ax=ax[2])
276+
200277 plt .show ()
201278
202279 if save :
203280 filename = self .session_path .joinpath (self .output_collection , '_sp_video.times.npy' )
281+ np .save (filename , frame_times )
204282 out_files = [filename ]
205283 else :
206284 out_files = []
0 commit comments