|
| 1 | +import logging |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | +import one.alf.io as alfio |
| 6 | +import ibldsp.utils |
| 7 | +from iblutil.spacer import Spacer |
| 8 | + |
| 9 | +from ibllib.pipes.base_tasks import BehaviourTask |
| 10 | +from ibllib.exceptions import SyncBpodFpgaException |
| 11 | +from ibllib.io.extractors.ephys_fpga import get_protocol_period, get_sync_fronts |
| 12 | +from ibllib.io.raw_daq_loaders import load_timeline_sync_and_chmap |
| 13 | +from ibllib.io.extractors.mesoscope import plot_timeline |
| 14 | + |
| 15 | +_logger = logging.getLogger('ibllib').getChild(__name__) |
| 16 | + |
| 17 | + |
| 18 | +class PassiveVideoTimeline(BehaviourTask): |
| 19 | + """Extraction task for _sp_passiveVideo protocol.""" |
| 20 | + priority = 90 |
| 21 | + job_size = 'small' |
| 22 | + |
| 23 | + @property |
| 24 | + def signature(self): |
| 25 | + signature = {} |
| 26 | + signature['input_files'] = [ |
| 27 | + ('_sp_taskData.raw.*', self.collection, True), # TODO Create dataset type? |
| 28 | + ('_iblrig_taskSettings.raw.*', self.collection, True), |
| 29 | + (f'_{self.sync_namespace}_DAQdata.raw.npy', self.sync_collection, True), |
| 30 | + (f'_{self.sync_namespace}_DAQdata.timestamps.npy', self.sync_collection, True), |
| 31 | + (f'_{self.sync_namespace}_DAQdata.meta.json', self.sync_collection, True), |
| 32 | + ] |
| 33 | + signature['output_files'] = [('_sp_video.times.npy', self.output_collection, True),] |
| 34 | + return signature |
| 35 | + |
| 36 | + def generate_sync_sequence(seed=1234, ns=3600, res=8): |
| 37 | + """Generate the sync square frame colour sequence. |
| 38 | +
|
| 39 | + Instead of changing each frame, the video sync square flips between black and white |
| 40 | + in a particular sequence defined within this function (in random multiples of res). |
| 41 | +
|
| 42 | + Parameters |
| 43 | + ---------- |
| 44 | + ns : int |
| 45 | + Related to the length in frames of the sequence (n_frames = ns * res). |
| 46 | + res : int |
| 47 | + The minimum number of sequential frames in each colour state. The N sequential frames |
| 48 | + is a multiple of this number. |
| 49 | + seed : int, optional |
| 50 | + The numpy random seed integer, by default 1234 |
| 51 | +
|
| 52 | + Returns |
| 53 | + ------- |
| 54 | + numpy.array |
| 55 | + An integer array of sync square states (one per frame) where 0 represents black and 1 |
| 56 | + represents white. |
| 57 | + """ |
| 58 | + state = np.random.get_state() |
| 59 | + try: |
| 60 | + np.random.seed(1234) |
| 61 | + seq = np.tile(np.random.random(ns), (res, 1)).T.flatten() |
| 62 | + return (seq > .5).astype(np.int8) |
| 63 | + finally: |
| 64 | + np.random.set_state(state) |
| 65 | + |
| 66 | + def extract_frame_times(self, save=True, frame_rate=60, display=False, **kwargs): |
| 67 | + """Extract the Bpod trials data and Timeline acquired signals. |
| 68 | +
|
| 69 | + Sync requires three steps: |
| 70 | + 1. Find protocol period using spacers |
| 71 | + 2. Find each video repeat with Bpod out |
| 72 | + 3. Find frame times with frame2ttl |
| 73 | +
|
| 74 | + Parameters |
| 75 | + ---------- |
| 76 | + save : bool, optional |
| 77 | + Whether to save the video frame times to file, by default True. |
| 78 | + frame_rate : int, optional |
| 79 | + The frame rate of the video presented, by default 60. |
| 80 | + display : bool, optional |
| 81 | + When true, plot the aligned frame times. By default False. |
| 82 | +
|
| 83 | + Returns |
| 84 | + ------- |
| 85 | + numpy.array |
| 86 | + The extracted frame times where N rows represent the number of frames and M columns |
| 87 | + represent the number of video repeats. The exact number of frames is not known and |
| 88 | + NaN values represent shorter video repeats. |
| 89 | + pathlib.Path |
| 90 | + The file path of the saved video times, or None if save=False. |
| 91 | +
|
| 92 | + Raises |
| 93 | + ------ |
| 94 | + ValueError |
| 95 | + The `protocol_number` property is None and no `tmin` or `tmax` values were passed as |
| 96 | + keyword arguments. |
| 97 | + SyncBpodFpgaException |
| 98 | + The synchronization of frame times was likely unsuccessful. |
| 99 | + """ |
| 100 | + _, (p,), _ = self.input_files[0].find_files(self.session_path) |
| 101 | + # Load raw data |
| 102 | + proc_data = pd.read_parquet(p) |
| 103 | + sync_path = self.session_path / self.sync_collection |
| 104 | + self.timeline = alfio.load_object(sync_path, 'DAQdata', namespace='timeline') |
| 105 | + sync, chmap = load_timeline_sync_and_chmap(sync_path, timeline=self.timeline) |
| 106 | + |
| 107 | + bpod = get_sync_fronts(sync, chmap['bpod']) |
| 108 | + # Get the spacer times for this protocol |
| 109 | + if any(arg in kwargs for arg in ('tmin', 'tmax')): |
| 110 | + tmin, tmax = kwargs.get('tmin'), kwargs.get('tmax') |
| 111 | + elif self.protocol_number is None: |
| 112 | + raise ValueError('Protocol number not defined') |
| 113 | + else: |
| 114 | + # The spacers are TTLs generated by Bpod at the start of each protocol |
| 115 | + tmin, tmax = get_protocol_period(self.session_path, self.protocol_number, bpod) |
| 116 | + tmin += (Spacer().times[-1] + Spacer().tup + 0.05) # exclude spacer itself |
| 117 | + |
| 118 | + # Remove unnecessary data from sync |
| 119 | + selection = np.logical_and( |
| 120 | + sync['times'] <= (tmax if tmax is not None else sync['times'][-1]), |
| 121 | + sync['times'] >= (tmin if tmin is not None else sync['times'][0]), |
| 122 | + ) |
| 123 | + sync = alfio.AlfBunch({k: v[selection] for k, v in sync.items()}) |
| 124 | + bpod = get_sync_fronts(sync, chmap['bpod']) |
| 125 | + _logger.debug('Protocol period from %.2fs to %.2fs (~%.0f min duration)', |
| 126 | + *sync['times'][[0, -1]], np.diff(sync['times'][[0, -1]]) / 60) |
| 127 | + |
| 128 | + # For each period of video playback the Bpod should output voltage HIGH |
| 129 | + bpod_rep_starts, = np.where(bpod['polarities'] == 1) |
| 130 | + _logger.info('N video repeats: %i; N Bpod pulses: %i', len(proc_data), len(bpod_rep_starts)) |
| 131 | + assert len(bpod_rep_starts) == len(proc_data) |
| 132 | + |
| 133 | + # These durations are longer than video actually played and will be cut down after |
| 134 | + durations = (proc_data['intervals_1'] - proc_data['intervals_0']).values |
| 135 | + max_n_frames = np.max(np.ceil(durations * frame_rate).astype(int)) |
| 136 | + frame_times = np.full((max_n_frames, len(proc_data)), np.nan) |
| 137 | + |
| 138 | + sync_sequence = kwargs.get('sync_sequence', self.generate_sync_sequence()) |
| 139 | + for i, rep in proc_data.iterrows(): |
| 140 | + # Get the frame2ttl times for the video presentation |
| 141 | + idx = bpod_rep_starts[i] |
| 142 | + start = bpod['times'][idx] |
| 143 | + try: |
| 144 | + end = bpod['times'][idx + 1] |
| 145 | + except IndexError: |
| 146 | + _logger.warning('Final Bpod LOW missing') |
| 147 | + end = start + (rep['intervals_1'] - rep['intervals_0']) |
| 148 | + f2ttl = get_sync_fronts(sync, chmap['frame2ttl']) |
| 149 | + ts = f2ttl['times'][np.logical_and(f2ttl['times'] >= start, f2ttl['times'] < end)] |
| 150 | + |
| 151 | + # video_runtime is the video length reported by VLC. |
| 152 | + # As it was added later, the less accurate media player timestamps may be used if the former is not available |
| 153 | + duration = rep.get('video_runtime') or (rep['MediaPlayerEndReached'] - rep['MediaPlayerPlaying']) |
| 154 | + # Start the sync sequence times at the start of the first frame2ttl flip (ts[0]) as this makes syncing more |
| 155 | + # performant because the offset is small |
| 156 | + sequence_times = np.arange(0, duration, 1 / frame_rate) |
| 157 | + sequence_times += ts[0] |
| 158 | + # The below assertion could be caused by an incorrect frame rate or sync sequence |
| 159 | + assert sequence_times.size <= sync_sequence.size, 'video duration appears longer than sync sequence' |
| 160 | + # Keep only the part of the sequence that was shown |
| 161 | + x = sync_sequence[:len(sequence_times)] |
| 162 | + # Find change points (black <-> white indices) |
| 163 | + x, = np.where(np.abs(np.diff(x))) |
| 164 | + # Include first frame as change point |
| 165 | + x = np.r_[0, x] |
| 166 | + # Synchronize the two by aligning flip times |
| 167 | + DRIFT_THRESHOLD_PPM = 50 |
| 168 | + Fs = self.timeline['meta']['daqSampleRate'] |
| 169 | + fcn, drift = ibldsp.utils.sync_timestamps(sequence_times[x], ts, tbin=1 / Fs, linear=True) |
| 170 | + # Log any major drift or raise if too large |
| 171 | + if np.abs(drift) > DRIFT_THRESHOLD_PPM * 2 and x.size - ts.size > 100: |
| 172 | + raise SyncBpodFpgaException(f'sync cluster f*ck: drift = {drift:.2f}, changepoint difference = {x.size - ts.size}') |
| 173 | + elif drift > DRIFT_THRESHOLD_PPM: |
| 174 | + _logger.warning('BPOD/FPGA synchronization shows values greater than %.2f ppm', |
| 175 | + DRIFT_THRESHOLD_PPM) |
| 176 | + |
| 177 | + # Get the frame times in timeline time |
| 178 | + frame_times[:len(sequence_times), i] = fcn(sequence_times) |
| 179 | + |
| 180 | + # Trim down to length of repeat with most frames |
| 181 | + frame_times = frame_times[:np.where(np.all(np.isnan(frame_times), axis=1))[0][0], :] |
| 182 | + |
| 183 | + if display: |
| 184 | + import matplotlib.pyplot as plt |
| 185 | + from matplotlib import colormaps |
| 186 | + from ibllib.plots import squares |
| 187 | + plot_timeline(self.timeline, channels=['bpod', 'frame2ttl']) |
| 188 | + _, ax = plt.subplots(2, 1, sharex=True) |
| 189 | + squares(f2ttl['times'], f2ttl['polarities'], ax=ax[0]) |
| 190 | + ax[0].set_yticks((-1, 1)) |
| 191 | + ax[0].title.set_text('frame2ttl') |
| 192 | + cmap = colormaps['plasma'] |
| 193 | + for i, times in enumerate(frame_times.T): |
| 194 | + rgba = cmap(i / frame_times.shape[1]) |
| 195 | + ax[1].plot(times, sync_sequence[:len(times)], c=rgba, label=f'{i}') |
| 196 | + ax[1].title.set_text('aligned sync square sequence') |
| 197 | + ax[1].set_yticks((0, 1)) |
| 198 | + ax[1].set_yticklabels([-1, 1]) |
| 199 | + plt.legend(markerfirst=False, title='repeat #', loc='upper right', facecolor='white') |
| 200 | + plt.show() |
| 201 | + |
| 202 | + if save: |
| 203 | + filename = self.session_path.joinpath(self.output_collection, '_sp_video.times.npy') |
| 204 | + out_files = [filename] |
| 205 | + else: |
| 206 | + out_files = [] |
| 207 | + |
| 208 | + return {'video_times': frame_times}, out_files |
| 209 | + |
| 210 | + def run_qc(self, **_): |
| 211 | + raise NotImplementedError |
| 212 | + |
| 213 | + def _run(self, save=True, **kwargs): |
| 214 | + _, output_files = self.extract_frame_times(save=save, **kwargs) |
| 215 | + return output_files |
0 commit comments