|
1 | 1 | """Extraction pipeline for Alejandro's learning_witten_dop project, task protocol _iblrig_tasks_FPChoiceWorld6.4.2""" |
2 | | -import logging |
3 | 2 | from inspect import getmembers, isfunction |
| 3 | +import logging |
4 | 4 |
|
5 | | -import numpy as np |
6 | 5 | import pandas as pd |
| 6 | +import numpy as np |
| 7 | +import scipy.interpolate |
| 8 | + |
7 | 9 | import one.alf.io as alfio |
8 | 10 | from one.alf.exceptions import ALFObjectNotFound |
9 | 11 | from one.alf.spec import is_session_path |
10 | | -from iblutil.util import Bunch |
11 | 12 |
|
12 | | -from ibllib.io.extractors.fibrephotometry import FibrePhotometry as BaseFibrePhotometry |
13 | | -from ibllib.io.extractors.fibrephotometry import DAQ_CHMAP, NEUROPHOTOMETRICS_LED_STATES |
14 | | -from ibllib.pipes.photometry_tasks import FibrePhotometryPreprocess as PhotometryPreprocess |
| 13 | +from ibllib.io.extractors.base import BaseExtractor |
| 14 | +from ibllib.io.raw_daq_loaders import load_channels_tdms, load_raw_daq_tdms |
| 15 | +from ibllib.io.extractors.training_trials import GoCueTriggerTimes |
| 16 | +from ibldsp.utils import rises, sync_timestamps |
| 17 | +from iblutil.util import Bunch |
15 | 18 | from ibllib.io import raw_daq_loaders |
16 | 19 | from ibllib.qc.base import QC |
17 | | -from scipy import interpolate |
| 20 | +from ibllib.pipes import base_tasks |
| 21 | + |
| 22 | +_logger = logging.getLogger('ibllib').getChild(__name__.split('.')[-1]) |
| 23 | + |
| 24 | +"""Data extraction from fibrephotometry DAQ files. |
| 25 | +
|
| 26 | +Below is the expected folder structure for a fibrephotometry session: |
| 27 | +
|
| 28 | + subject/ |
| 29 | + ├─ 2021-06-30/ |
| 30 | + │ ├─ 001/ |
| 31 | + │ │ ├─ raw_photometry_data/ |
| 32 | + │ │ │ │ ├─ _neurophotometrics_fpData.raw.pqt |
| 33 | + │ │ │ │ ├─ _neurophotometrics_fpData.channels.csv |
| 34 | + │ │ │ │ ├─ _mcc_DAQdata.raw.tdms |
| 35 | +
|
| 36 | +fpData.raw.pqt is a copy of the 'FPdata' file, the output of the Neurophotometrics Bonsai workflow.
| 37 | +fpData.channels.csv is a table of frame flags for deciphering LED and GPIO states. The default table,
| 38 | +copied from the Neurophotometrics manual, can be found in iblscripts/deploy/fppc/
| 39 | +_mcc_DAQdata.raw.tdms is the DAQ tdms file, containing the pulses from Bpod and from the Neurophotometrics system.
| 40 | +
|
| 41 | +Neurophotometrics FP3002 specific information. |
| 42 | +The light source map refers to the available LEDs on the system. |
| 43 | +The flags refer to the byte encoding of LED states in the system.
| 44 | +""" |
| 45 | +LIGHT_SOURCE_MAP = { |
| 46 | + 'color': ['None', 'Violet', 'Blue', 'Green'], |
| 47 | + 'wavelength': [0, 415, 470, 560], |
| 48 | + 'name': ['None', 'Isosbestic', 'GCaMP', 'RCaMP'], |
| 49 | +} |
| 50 | + |
| 51 | +NEUROPHOTOMETRICS_LED_STATES = { |
| 52 | + 'Condition': { |
| 53 | + 0: 'No additional signal', |
| 54 | + 1: 'Output 1 signal HIGH', |
| 55 | + 2: 'Output 0 signal HIGH', |
| 56 | + 3: 'Stimulation ON', |
| 57 | + 4: 'GPIO Line 2 HIGH', |
| 58 | + 5: 'GPIO Line 3 HIGH', |
| 59 | + 6: 'Input 1 HIGH', |
| 60 | + 7: 'Input 0 HIGH', |
| 61 | + 8: 'Output 0 signal HIGH + Stimulation', |
| 62 | + 9: 'Output 0 signal HIGH + Input 0 signal HIGH', |
| 63 | + 10: 'Input 0 signal HIGH + Stimulation', |
| 64 | + 11: 'Output 0 HIGH + Input 0 HIGH + Stimulation', |
| 65 | + }, |
| 66 | + 'No LED ON': {0: 0, 1: 8, 2: 16, 3: 32, 4: 64, 5: 128, 6: 256, 7: 512, 8: 48, 9: 528, 10: 544, 11: 560}, |
| 67 | + 'L415': {0: 1, 1: 9, 2: 17, 3: 33, 4: 65, 5: 129, 6: 257, 7: 513, 8: 49, 9: 529, 10: 545, 11: 561}, |
| 68 | + 'L470': {0: 2, 1: 10, 2: 18, 3: 34, 4: 66, 5: 130, 6: 258, 7: 514, 8: 50, 9: 530, 10: 546, 11: 562}, |
| 69 | + 'L560': {0: 4, 1: 12, 2: 20, 3: 36, 4: 68, 5: 132, 6: 260, 7: 516, 8: 52, 9: 532, 10: 548, 11: 564} |
| 70 | +} |
18 | 71 |
|
19 | 72 | CHANNELS = pd.DataFrame.from_dict(NEUROPHOTOMETRICS_LED_STATES) |
| 73 | +DAQ_CHMAP = {"photometry": 'AI0', 'bpod': 'AI1'} |
| 74 | +V_THRESHOLD = 3 |
| 75 | + |
| 76 | + |
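To make the byte encoding concrete, here is a small illustrative sketch (not part of the changeset) of how a single `LedState` flag can be decoded against this table, mirroring the lookup done in `_extract` further down; the `flag` value of 18 is an arbitrary example, not real data:

```python
import numpy as np
import pandas as pd

# build the lookup table as CHANNELS above, indexed by the human-readable condition
led_states = pd.DataFrame.from_dict(NEUROPHOTOMETRICS_LED_STATES).set_index('Condition')

flag = 18  # hypothetical value read from the 'LedState' column of fpData.raw
row, col = np.where(led_states == flag)
print(led_states.columns[col[0]])  # 'L470' -> the 470 nm (GCaMP) LED was on for this frame
print(led_states.index[row[0]])    # 'Output 0 signal HIGH' -> the accompanying GPIO condition
```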
| 77 | +def sync_photometry_to_daq(vdaq, fs, df_photometry, chmap=DAQ_CHMAP, v_threshold=V_THRESHOLD): |
| 78 | + """ |
| 79 | +    :param vdaq: dictionary of DAQ voltage traces, keyed as per the channel map
| 80 | +    :param fs: DAQ sampling frequency (Hz)
| 81 | +    :param df_photometry: pandas.DataFrame of the raw photometry data (fpData.raw table)
| 82 | +    :param chmap: dictionary mapping channel labels ('photometry', 'bpod') to DAQ analog inputs
| 83 | +    :param v_threshold: voltage threshold (V) used to detect frame pulses
| 84 | +    :return: frame timestamps in DAQ time, photometry-to-DAQ interpolation function, drift in ppm
| 85 | + """ |
| 86 | +    # detect the photometry frame pulses on the DAQ and whether the bpod channel was high at each frame
| 87 | + daq_frames, tag_daq_frames = read_daq_timestamps(vdaq=vdaq, v_threshold=v_threshold) |
| 88 | + nf = np.minimum(tag_daq_frames.size, df_photometry['Input0'].size) |
| 89 | + |
| 90 | +    # we compute a frame counter for the DAQ and match the bpod up-state frame by frame over a range of shifts
| 91 | +    # the shift that minimises the mismatch gives the frame offset between the two streams
| 92 | + df = np.median(np.diff(df_photometry['Timestamp'])) |
| 93 | + fc = np.cumsum(np.round(np.diff(daq_frames) / fs / df).astype(np.int32)) - 1 # this is a daq frame counter |
| 94 | + fc = fc[fc < (nf - 1)] |
| 95 | + max_shift = 300 |
| 96 | + error = np.zeros(max_shift * 2 + 1) |
| 97 | + shifts = np.arange(-max_shift, max_shift + 1) |
| 98 | + for i, shift in enumerate(shifts): |
| 99 | + rolled_fp = np.roll(df_photometry['Input0'].values[fc], shift) |
| 100 | + error[i] = np.sum(np.abs(rolled_fp - tag_daq_frames[:fc.size])) |
| 101 | +    # a negative shift means that the DAQ is ahead of the photometry and that the DAQ misses frames at the beginning
| 102 | + frame_shift = shifts[np.argmax(-error)] |
| 103 | + if np.sign(frame_shift) == -1: |
| 104 | + ifp = fc[np.abs(frame_shift):] |
| 105 | + elif np.sign(frame_shift) == 0: |
| 106 | + ifp = fc |
| 107 | + elif np.sign(frame_shift) == 1: |
| 108 | + ifp = fc[:-np.abs(frame_shift)] |
| 109 | + t_photometry = df_photometry['Timestamp'].values[ifp] |
| 110 | + t_daq = daq_frames[:ifp.size] / fs |
| 111 | + # import matplotlib.pyplot as plt |
| 112 | + # plt.plot(shifts, -error) |
| 113 | + fcn_fp2daq = scipy.interpolate.interp1d(t_photometry, t_daq, fill_value='extrapolate') |
| 114 | + drift_ppm = (np.polyfit(t_daq, t_photometry, 1)[0] - 1) * 1e6 |
| 115 | + if drift_ppm > 120: |
| 116 | + _logger.warning(f"drift photometry to DAQ PPM: {drift_ppm}") |
| 117 | + else: |
| 118 | + _logger.info(f"drift photometry to DAQ PPM: {drift_ppm}") |
| 119 | + # here is a bunch of safeguards |
| 120 | + assert np.unique(np.diff(df_photometry['FrameCounter'])).size == 1 # checks that there are no missed frames on photo |
| 121 | + assert np.abs(frame_shift) <= 5 # it's always the end frames that are missing |
| 122 | + assert np.abs(drift_ppm) < 60 |
| 123 | + ts_daq = fcn_fp2daq(df_photometry['Timestamp'].values) # those are the timestamps in daq time |
| 124 | + return ts_daq, fcn_fp2daq, drift_ppm |
| 125 | + |
| 126 | + |
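For context, the shift search above can be pictured on a toy example. The sketch below is not part of the changeset and the arrays are made up; it only shows how rolling one tag stream against the other and minimising the mismatch recovers the frame offset:

```python
import numpy as np

# made-up frame tags: 1 where the bpod line was high at the frame time
daq_tags = np.array([0, 0, 1, 1, 0, 1, 0, 0])
fp_input0 = np.roll(daq_tags, -2)  # pretend the photometry tags are offset by two frames

shifts = np.arange(-3, 4)
error = np.array([np.sum(np.abs(np.roll(fp_input0, s) - daq_tags)) for s in shifts])
best_shift = shifts[np.argmin(error)]  # -> 2: rolling the photometry tags by 2 aligns the streams
```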
| 127 | +def read_daq_voltage(daq_file, chmap=DAQ_CHMAP): |
| 128 | + channel_names = [c.name for c in load_raw_daq_tdms(daq_file)['Analog'].channels()] |
| 129 | + assert all([v in channel_names for v in chmap.values()]), "Missing channel" |
| 130 | + vdaq, fs = load_channels_tdms(daq_file, chmap=chmap) |
| 131 | + vdaq = {k: v - np.median(v) for k, v in vdaq.items()} |
| 132 | + return vdaq, fs |
| 133 | + |
| 134 | + |
| 135 | +def read_daq_timestamps(vdaq, v_threshold=V_THRESHOLD): |
| 136 | + """ |
| 137 | +    From the DAQ voltage traces, extracts the photometry frame times and their tagging.
| 138 | +    :param vdaq: dictionary of the voltage traces from the DAQ. Each item has a key describing
| 139 | +    the channel as per the channel map, and contains a single voltage trace.
| 140 | +    :param v_threshold: voltage threshold (V) used to detect frame pulses
| 141 | +    :return: daq_frames (sample indices of the frame pulses), tagged_frames (bool, bpod high at frame)
| 142 | + """ |
| 143 | + daq_frames = rises(vdaq['photometry'], step=v_threshold, analog=True) |
| 144 | + if daq_frames.size == 0: |
| 145 | + daq_frames = rises(-vdaq['photometry'], step=v_threshold, analog=True) |
| 146 | +        _logger.warning(f'No photometry pulses detected, attempting to reverse voltage and detect again, '
| 147 | + f'found {daq_frames.size} in reverse voltage. CHECK YOUR FP WIRING TO THE DAQ !!') |
| 148 | + tagged_frames = vdaq['bpod'][daq_frames] > v_threshold |
| 149 | + return daq_frames, tagged_frames |
| 150 | + |
| 151 | + |
| 152 | +def check_timestamps(daq_file, photometry_file, tolerance=20, chmap=DAQ_CHMAP, v_threshold=V_THRESHOLD): |
| 153 | + """ |
| 154 | +    Reads the data files and checks that the numbers of timestamps agree within a tolerance of n frames.
| 155 | +    :param daq_file: path to the raw DAQ tdms file
| 156 | +    :param photometry_file: path to the raw photometry csv file
| 157 | +    :param tolerance: number of acceptable missing frames between the daq and the photometry file
| 158 | +    :param chmap: dictionary mapping channel labels ('photometry', 'bpod') to DAQ analog inputs
| 159 | +    :param v_threshold: voltage threshold (V) used to detect frame pulses
| 160 | + :return: None |
| 161 | + """ |
| 162 | + df_photometry = pd.read_csv(photometry_file) |
| 163 | + v, fs = read_daq_voltage(daq_file=daq_file, chmap=chmap) |
| 164 | + daq_frames, _ = read_daq_timestamps(vdaq=v, v_threshold=v_threshold) |
| 165 | + assert (daq_frames.shape[0] - df_photometry.shape[0]) < tolerance |
| 166 | + _logger.info(f"{daq_frames.shape[0] - df_photometry.shape[0]} frames difference, " |
| 167 | + f"{'/'.join(daq_file.parts[-2:])}: {daq_frames.shape[0]} frames, " |
| 168 | + f"{'/'.join(photometry_file.parts[-2:])}: {df_photometry.shape[0]}") |
| 169 | + |
| 170 | + |
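A minimal usage sketch of the module-level helpers above, assuming the session layout documented at the top of the file; the paths are hypothetical and the raw photometry table is simply loaded with pandas:

```python
from pathlib import Path
import pandas as pd

# hypothetical session following the documented folder structure
collection = Path('/data/subject/2021-06-30/001/raw_photometry_data')
daq_file = collection / '_mcc_DAQdata.raw.tdms'
df_photometry = pd.read_parquet(collection / '_neurophotometrics_fpData.raw.pqt')

vdaq, fs = read_daq_voltage(daq_file, chmap=DAQ_CHMAP)            # median-subtracted voltage traces
daq_frames, tagged = read_daq_timestamps(vdaq, v_threshold=V_THRESHOLD)
ts_daq, fcn_fp2daq, drift_ppm = sync_photometry_to_daq(vdaq, fs, df_photometry)
```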
| 171 | +class BaseFibrePhotometry(BaseExtractor): |
| 172 | + """ |
| 173 | +    Usage: FibrePhotometry(self.session_path, collection=self.collection)
| 174 | + """ |
| 175 | + save_names = ('photometry.signal.pqt') |
| 176 | + var_names = ('df_out') |
| 177 | + |
| 178 | + def __init__(self, *args, collection='raw_photometry_data', **kwargs): |
| 179 | + """An extractor for all Neurophotometrics fibrephotometry data""" |
| 180 | + self.collection = collection |
| 181 | + super().__init__(*args, **kwargs) |
20 | 182 |
|
21 | | -_logger = logging.getLogger('ibllib').getChild(__name__.split('.')[-1]) |
| 183 | + @staticmethod |
| 184 | + def _channel_meta(light_source_map=None): |
| 185 | + """ |
| 186 | + Return table of light source wavelengths and corresponding colour labels. |
| 187 | +
|
| 188 | + Parameters |
| 189 | + ---------- |
| 190 | + light_source_map : dict |
| 191 | + An optional map of light source wavelengths (nm) used and their corresponding colour name. |
| 192 | +
|
| 193 | + Returns |
| 194 | + ------- |
| 195 | + pandas.DataFrame |
| 196 | + A sorted table of wavelength and colour name. |
| 197 | + """ |
| 198 | + light_source_map = light_source_map or LIGHT_SOURCE_MAP |
| 199 | + meta = pd.DataFrame.from_dict(light_source_map) |
| 200 | + meta.index.rename('channel_id', inplace=True) |
| 201 | + return meta |
| 202 | + |
| 203 | + def _extract(self, light_source_map=None, collection=None, regions=None, **kwargs): |
| 204 | + """ |
| 205 | +
|
| 206 | + Parameters |
| 207 | + ---------- |
| 208 | + regions: list of str |
| 209 | + The list of regions to extract. If None extracts all columns containing "Region". Defaults to None. |
| 210 | + light_source_map : dict |
| 211 | + An optional map of light source wavelengths (nm) used and their corresponding colour name. |
| 212 | + collection: str / pathlib.Path |
| 213 | + An optional relative path from the session root folder to find the raw photometry data. |
| 214 | + Defaults to `raw_photometry_data` |
| 215 | +
|
| 216 | + Returns |
| 217 | + ------- |
| 218 | +        pandas.DataFrame
| 219 | +            A table of intensity for each region, with associated times, wavelengths, names and colours.
| 224 | + """ |
| 225 | + collection = collection or self.collection |
| 226 | + fp_data = alfio.load_object(self.session_path / collection, 'fpData') |
| 227 | + ts = self.extract_timestamps(fp_data['raw'], **kwargs) |
| 228 | + |
| 229 | +        # Load the light source channel map and the LED state flags table
| 230 | + channel_meta_map = self._channel_meta(kwargs.get('light_source_map')) |
| 231 | + led_states = fp_data.get('channels', pd.DataFrame(NEUROPHOTOMETRICS_LED_STATES)) |
| 232 | + led_states = led_states.set_index('Condition') |
| 233 | + # Extract signal columns into 2D array |
| 234 | + regions = regions or [k for k in fp_data['raw'].keys() if 'Region' in k] |
| 235 | + out_df = fp_data['raw'].filter(items=regions, axis=1).sort_index(axis=1) |
| 236 | + out_df['times'] = ts |
| 237 | + out_df['wavelength'] = np.nan |
| 238 | + out_df['name'] = '' |
| 239 | + out_df['color'] = '' |
| 240 | + # Extract channel index |
| 241 | + states = fp_data['raw'].get('LedState', fp_data['raw'].get('Flags', None)) |
| 242 | + for state in states.unique(): |
| 243 | + ir, ic = np.where(led_states == state) |
| 244 | + if ic.size == 0: |
| 245 | + continue |
| 246 | + for cn in ['name', 'color', 'wavelength']: |
| 247 | + out_df.loc[states == state, cn] = channel_meta_map.iloc[ic[0]][cn] |
| 248 | + return out_df |
| 249 | + |
| 250 | + def extract_timestamps(self, fp_data, **kwargs): |
| 251 | + """Extract the photometry.timestamps array. |
| 252 | +
|
| 253 | + This depends on the DAQ and task synchronization protocol. |
| 254 | +
|
| 255 | + Parameters |
| 256 | + ---------- |
| 257 | +        fp_data : pandas.DataFrame
| 258 | +            The table of raw fibrephotometry data (the 'raw' attribute of the loaded fpData object).
| 259 | +
|
| 260 | + Returns |
| 261 | + ------- |
| 262 | + numpy.ndarray |
| 263 | + An array of timestamps, one per frame. |
| 264 | + """ |
| 265 | + daq_file = next(self.session_path.joinpath(self.collection).glob('*.tdms')) |
| 266 | + vdaq, fs = read_daq_voltage(daq_file, chmap=DAQ_CHMAP) |
| 267 | + ts, fcn_daq2_, drift_ppm = sync_photometry_to_daq( |
| 268 | + vdaq=vdaq, fs=fs, df_photometry=fp_data, v_threshold=V_THRESHOLD) |
| 269 | + gc_bpod, _ = GoCueTriggerTimes(session_path=self.session_path).extract(task_collection='raw_behavior_data', save=False) |
| 270 | + gc_daq = rises(vdaq['bpod']) |
| 271 | + |
| 272 | + fcn_daq2_bpod, drift_ppm, idaq, ibp = sync_timestamps( |
| 273 | + rises(vdaq['bpod']) / fs, gc_bpod, return_indices=True) |
| 274 | + assert drift_ppm < 100, f"Drift between bpod and daq is above 100 ppm: {drift_ppm}" |
| 275 | +        assert (gc_daq.size - idaq.size) < 5, "Bpod and daq synchronisation failed as too few " \
| 276 | + "events could be matched" |
| 277 | + ts = fcn_daq2_bpod(ts) |
| 278 | + return ts |
22 | 279 |
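A usage sketch for the extractor class, assuming the standard `BaseExtractor.extract()` entry point (which wraps `_extract` and, with `save=True`, writes the `save_names` output); the session path is hypothetical:

```python
# hypothetical session path; the collection defaults to 'raw_photometry_data'
fp_extractor = BaseFibrePhotometry('/data/subject/2021-06-30/001')
df_signal, out_files = fp_extractor.extract(save=True)

# keep only the frames acquired with the 470 nm (GCaMP) light source
gcamp = df_signal[df_signal['name'] == 'GCaMP']
```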
|
23 | 280 |
|
24 | 281 | # upload to the session endpoint, qc per regions |
@@ -295,15 +552,33 @@ def sync_timestamps(daq_data, fp_data, trials): |
295 | 552 |
|
296 | 553 | use_times = ~fp_data['bpod_times'].isna() |
297 | 554 |
|
298 | | - fcn = interpolate.interp1d(fp_data['Timestamp'][use_times].values, fp_data['bpod_times'][use_times].values, |
299 | | - fill_value="extrapolate") |
| 555 | + fcn = scipy.interpolate.interp1d( |
| 556 | + fp_data['Timestamp'][use_times].values, fp_data['bpod_times'][use_times].values, fill_value="extrapolate") |
300 | 557 |
|
301 | 558 | ts = fcn(fp_data['Timestamp'].values) |
302 | 559 |
|
303 | 560 | return ts |
304 | 561 |
|
305 | 562 |
|
306 | | -class FibrePhotometryPreprocess(PhotometryPreprocess): |
| 563 | +class FibrePhotometryPreprocess(base_tasks.DynamicTask): |
| 564 | + @property |
| 565 | + def signature(self): |
| 566 | + signature = { |
| 567 | + 'input_files': [('_mcc_DAQdata.raw.tdms', self.device_collection, True), |
| 568 | + ('_neurophotometrics_fpData.raw.pqt', self.device_collection, True)], |
| 569 | + 'output_files': [('photometry.signal.pqt', 'alf/photometry', True)] |
| 570 | + } |
| 571 | + return signature |
| 572 | + |
| 573 | + priority = 90 |
| 574 | + level = 1 |
| 575 | + |
| 576 | + def __init__(self, session_path, regions=None, **kwargs): |
| 577 | + super().__init__(session_path, **kwargs) |
| 578 | + # Task collection (this needs to be specified in the task kwargs) |
| 579 | + self.collection = self.get_task_collection(kwargs.get('collection', None)) |
| 580 | + self.device_collection = self.get_device_collection('photometry', device_collection='raw_photometry_data') |
| 581 | + self.regions = regions |
307 | 582 |
|
308 | 583 | def _run(self, **kwargs): |
309 | 584 | _, out_files = FibrePhotometry(self.session_path, collection=self.device_collection).extract( |
|