Skip to content

Commit 61aa855

Browse files
committed
move ibllib deprecated code here for reference
1 parent c3ae3e2 commit 61aa855

File tree

1 file changed

+286
-11
lines changed

1 file changed

+286
-11
lines changed

projects/biased_fibrephotometry.py

Lines changed: 286 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,281 @@
11
"""Extraction pipeline for Alejandro's learning_witten_dop project, task protocol _iblrig_tasks_FPChoiceWorld6.4.2"""
2-
import logging
32
from inspect import getmembers, isfunction
3+
import logging
44

5-
import numpy as np
65
import pandas as pd
6+
import numpy as np
7+
import scipy.interpolate
8+
79
import one.alf.io as alfio
810
from one.alf.exceptions import ALFObjectNotFound
911
from one.alf.spec import is_session_path
10-
from iblutil.util import Bunch
1112

12-
from ibllib.io.extractors.fibrephotometry import FibrePhotometry as BaseFibrePhotometry
13-
from ibllib.io.extractors.fibrephotometry import DAQ_CHMAP, NEUROPHOTOMETRICS_LED_STATES
14-
from ibllib.pipes.photometry_tasks import FibrePhotometryPreprocess as PhotometryPreprocess
13+
from ibllib.io.extractors.base import BaseExtractor
14+
from ibllib.io.raw_daq_loaders import load_channels_tdms, load_raw_daq_tdms
15+
from ibllib.io.extractors.training_trials import GoCueTriggerTimes
16+
from ibldsp.utils import rises, sync_timestamps
17+
from iblutil.util import Bunch
1518
from ibllib.io import raw_daq_loaders
1619
from ibllib.qc.base import QC
17-
from scipy import interpolate
20+
from ibllib.pipes import base_tasks
21+
22+
_logger = logging.getLogger('ibllib').getChild(__name__.split('.')[-1])
23+
24+
"""Data extraction from fibrephotometry DAQ files.
25+
26+
Below is the expected folder structure for a fibrephotometry session:
27+
28+
subject/
29+
├─ 2021-06-30/
30+
│ ├─ 001/
31+
│ │ ├─ raw_photometry_data/
32+
│ │ │ │ ├─ _neurophotometrics_fpData.raw.pqt
33+
│ │ │ │ ├─ _neurophotometrics_fpData.channels.csv
34+
│ │ │ │ ├─ _mcc_DAQdata.raw.tdms
35+
36+
fpData.raw.pqt is a copy of the 'FPdata' file, the output of the Neurophotometrics Bonsai workflow.
37+
fpData.channels.csv is a table of frame flags for deciphering LED and GPIO states. The default table,
38+
copied from the Neurophotometrics manual can be found in iblscripts/deploy/fppc/
39+
_mcc_DAQdata.raw.tdms is the DAQ tdms file, containing the pulses from bpod and from the neurophotometrics system
40+
41+
Neurophotometrics FP3002 specific information.
42+
The light source map refers to the available LEDs on the system.
43+
The flags refer to the byte encoding of LED states in the system.
44+
"""
45+
# Light sources available on the system; the three lists are aligned position by position.
LIGHT_SOURCE_MAP = {
    'color': ['None', 'Violet', 'Blue', 'Green'],
    'wavelength': [0, 415, 470, 560],
    'name': ['None', 'Isosbestic', 'GCaMP', 'RCaMP'],
}

# Frame flag table: 'Condition' labels each row (0-11), every other column gives the
# flag value observed for that condition when the named LED (or no LED) is on.
NEUROPHOTOMETRICS_LED_STATES = {
    'Condition': {
        0: 'No additional signal',
        1: 'Output 1 signal HIGH',
        2: 'Output 0 signal HIGH',
        3: 'Stimulation ON',
        4: 'GPIO Line 2 HIGH',
        5: 'GPIO Line 3 HIGH',
        6: 'Input 1 HIGH',
        7: 'Input 0 HIGH',
        8: 'Output 0 signal HIGH + Stimulation',
        9: 'Output 0 signal HIGH + Input 0 signal HIGH',
        10: 'Input 0 signal HIGH + Stimulation',
        11: 'Output 0 HIGH + Input 0 HIGH + Stimulation',
    },
    'No LED ON': {0: 0, 1: 8, 2: 16, 3: 32, 4: 64, 5: 128, 6: 256, 7: 512, 8: 48, 9: 528, 10: 544, 11: 560},
    'L415': {0: 1, 1: 9, 2: 17, 3: 33, 4: 65, 5: 129, 6: 257, 7: 513, 8: 49, 9: 529, 10: 545, 11: 561},
    'L470': {0: 2, 1: 10, 2: 18, 3: 34, 4: 66, 5: 130, 6: 258, 7: 514, 8: 50, 9: 530, 10: 546, 11: 562},
    'L560': {0: 4, 1: 12, 2: 20, 3: 36, 4: 68, 5: 132, 6: 260, 7: 516, 8: 52, 9: 532, 10: 548, 11: 564},
}

# Same table as a DataFrame: one column per key above, one row per condition.
CHANNELS = pd.DataFrame.from_dict(NEUROPHOTOMETRICS_LED_STATES)
# Trace name -> DAQ analog channel name.
DAQ_CHMAP = {"photometry": 'AI0', 'bpod': 'AI1'}
# Default threshold applied to the DAQ voltage traces when detecting and tagging pulses.
V_THRESHOLD = 3
75+
76+
77+
def sync_photometry_to_daq(vdaq, fs, df_photometry, chmap=DAQ_CHMAP, v_threshold=V_THRESHOLD):
    """
    Express the photometry frame timestamps in DAQ time.

    The frame pulses detected on the DAQ are aligned to the photometry frames by comparing the
    bpod tag seen by both systems over a range of frame shifts and keeping the shift with the
    smallest mismatch. The matched frames are then used to build an interpolant from photometry
    time to DAQ time.

    :param vdaq: dictionary of daq traces, read through read_daq_timestamps
    :param fs: sampling frequency of the daq traces
    :param df_photometry: table of photometry frames; the columns 'Input0', 'Timestamp' and
     'FrameCounter' are read
    :param chmap: accepted for interface compatibility, not used by this function
    :param v_threshold: threshold forwarded to read_daq_timestamps
    :return: ts_daq (the photometry timestamps in daq time), fcn_fp2daq (the interpolation function
     from photometry time to daq time), drift_ppm (slope of photometry time against daq time,
     minus one, in parts per million)
    :raises AssertionError: if the photometry frame counter is irregular, if the best frame shift
     exceeds 5 frames, or if the absolute drift is 60 ppm or more
    """
    daq_frames, tag_daq_frames = read_daq_timestamps(vdaq=vdaq, v_threshold=v_threshold)
    nf = np.minimum(tag_daq_frames.size, df_photometry['Input0'].size)

    # we compute the framecounter for the DAQ, and match the bpod up state frame by frame for different shifts
    # the shift that minimizes the mismatch is usually good
    df = np.median(np.diff(df_photometry['Timestamp']))
    fc = np.cumsum(np.round(np.diff(daq_frames) / fs / df).astype(np.int32)) - 1  # this is a daq frame counter
    fc = fc[fc < (nf - 1)]
    max_shift = 300
    shifts = np.arange(-max_shift, max_shift + 1)
    # both tag vectors are cast to float: numpy refuses to subtract two boolean arrays
    fp_tags = df_photometry['Input0'].values[fc].astype(float)
    daq_tags = tag_daq_frames[:fc.size].astype(float)
    error = np.array([np.sum(np.abs(np.roll(fp_tags, shift) - daq_tags)) for shift in shifts])
    # a negative shift means that the DAQ is ahead of the photometry and that the DAQ misses frame at the beginning
    frame_shift = shifts[np.argmin(error)]
    if frame_shift < 0:
        ifp = fc[-frame_shift:]
    elif frame_shift > 0:
        ifp = fc[:-frame_shift]
    else:
        ifp = fc
    t_photometry = df_photometry['Timestamp'].values[ifp]
    t_daq = daq_frames[:ifp.size] / fs
    fcn_fp2daq = scipy.interpolate.interp1d(t_photometry, t_daq, fill_value='extrapolate')
    drift_ppm = (np.polyfit(t_daq, t_photometry, 1)[0] - 1) * 1e6
    # the sign of the drift is irrelevant for the severity of the message
    if np.abs(drift_ppm) > 120:
        _logger.warning(f"drift photometry to DAQ PPM: {drift_ppm}")
    else:
        _logger.info(f"drift photometry to DAQ PPM: {drift_ppm}")
    # here is a bunch of safeguards
    # checks that there are no missed frames on photo
    assert np.unique(np.diff(df_photometry['FrameCounter'])).size == 1, \
        "Photometry FrameCounter does not increase by a constant step"
    # it's always the end frames that are missing
    assert np.abs(frame_shift) <= 5, f"Frame shift between photometry and DAQ is too large: {frame_shift}"
    assert np.abs(drift_ppm) < 60, f"Drift between photometry and DAQ is too large: {drift_ppm} ppm"
    ts_daq = fcn_fp2daq(df_photometry['Timestamp'].values)  # those are the timestamps in daq time
    return ts_daq, fcn_fp2daq, drift_ppm
125+
126+
127+
def read_daq_voltage(daq_file, chmap=DAQ_CHMAP):
    """
    Load the mapped channels of a DAQ tdms file and remove the median of each trace.

    :param daq_file: the DAQ tdms file to read
    :param chmap: dictionary whose values are the analog channel names expected in the file
    :return: vdaq (dictionary of median-subtracted traces, keyed as returned by
     load_channels_tdms) and fs (second value returned by load_channels_tdms)
    :raises AssertionError: if a channel listed in chmap is absent from the 'Analog' group
    """
    available = {channel.name for channel in load_raw_daq_tdms(daq_file)['Analog'].channels()}
    assert all(name in available for name in chmap.values()), "Missing channel"
    traces, fs = load_channels_tdms(daq_file, chmap=chmap)
    centred = {}
    for key, trace in traces.items():
        centred[key] = trace - np.median(trace)
    return centred, fs
133+
134+
135+
def read_daq_timestamps(vdaq, v_threshold=V_THRESHOLD):
    """
    From the DAQ voltage traces, extracts the photometry frames and their tagging.

    :param vdaq: dictionary of the voltage traces from the DAQ. Each item has a key describing
     the channel as per the channel map, and contains a single voltage trace. The keys
     'photometry' and 'bpod' are read.
    :param v_threshold: step used to detect the rises of the photometry trace, and level above
     which the bpod trace tags a frame
    :return: daq_frames (sample indices of the detected photometry pulses) and tagged_frames
     (boolean array, True where the bpod trace is above v_threshold at the frame sample)
    """
    daq_frames = rises(vdaq['photometry'], step=v_threshold, analog=True)
    if daq_frames.size == 0:
        # no pulse found: the photometry signal may be wired with reversed polarity
        daq_frames = rises(-vdaq['photometry'], step=v_threshold, analog=True)
        _logger.warning(f'No photometry pulses detected, attempting to reverse voltage and detect again, '
                        f'found {daq_frames.size} in reverse voltage. CHECK YOUR FP WIRING TO THE DAQ !!')
    tagged_frames = vdaq['bpod'][daq_frames] > v_threshold
    return daq_frames, tagged_frames
150+
151+
152+
def check_timestamps(daq_file, photometry_file, tolerance=20, chmap=DAQ_CHMAP, v_threshold=V_THRESHOLD):
    """
    Compare the number of frames detected on the DAQ with the number of rows of the photometry file.

    :param daq_file: DAQ tdms file; its last two path parts are used in the log message
    :param photometry_file: photometry file, read with pandas.read_csv; its last two path parts
     are used in the log message
    :param tolerance: the number of DAQ frames minus the number of photometry rows must be
     strictly below this value
    :param chmap: channel map forwarded to read_daq_voltage
    :param v_threshold: threshold forwarded to read_daq_timestamps
    :return: None
    :raises AssertionError: if the frame difference reaches the tolerance
    """
    n_photometry = pd.read_csv(photometry_file).shape[0]
    voltages, _ = read_daq_voltage(daq_file=daq_file, chmap=chmap)
    n_daq = read_daq_timestamps(vdaq=voltages, v_threshold=v_threshold)[0].shape[0]
    frame_delta = n_daq - n_photometry
    assert frame_delta < tolerance
    daq_label = '/'.join(daq_file.parts[-2:])
    photometry_label = '/'.join(photometry_file.parts[-2:])
    _logger.info(f"{frame_delta} frames difference, "
                 f"{daq_label}: {n_daq} frames, "
                 f"{photometry_label}: {n_photometry}")
169+
170+
171+
class BaseFibrePhotometry(BaseExtractor):
    """
    Extractor for Neurophotometrics fibrephotometry data synchronised through a DAQ.

    Example: BaseFibrePhotometry(session_path, collection='raw_photometry_data')
    """
    # NB: both attributes are plain strings (the original parentheses did not create tuples)
    save_names = 'photometry.signal.pqt'
    var_names = 'df_out'

    def __init__(self, *args, collection='raw_photometry_data', **kwargs):
        """An extractor for all Neurophotometrics fibrephotometry data"""
        self.collection = collection
        super().__init__(*args, **kwargs)

    @staticmethod
    def _channel_meta(light_source_map=None):
        """
        Return table of light source wavelengths and corresponding colour labels.

        Parameters
        ----------
        light_source_map : dict
            An optional dictionary of equal-length lists, laid out as LIGHT_SOURCE_MAP
            (keys 'color', 'wavelength', 'name'). Defaults to LIGHT_SOURCE_MAP.

        Returns
        -------
        pandas.DataFrame
            One row per light source, one column per key of the map; the index is named 'channel_id'.
        """
        light_source_map = light_source_map or LIGHT_SOURCE_MAP
        meta = pd.DataFrame.from_dict(light_source_map)
        meta.index.rename('channel_id', inplace=True)
        return meta

    def _extract(self, light_source_map=None, collection=None, regions=None, **kwargs):
        """
        Build the photometry signal table from the raw 'fpData' object.

        Parameters
        ----------
        regions: list of str
            The list of regions to extract. If None extracts all columns containing "Region". Defaults to None.
        light_source_map : dict
            An optional light source map, laid out as LIGHT_SOURCE_MAP. Defaults to LIGHT_SOURCE_MAP.
        collection: str / pathlib.Path
            An optional relative path from the session root folder to find the raw photometry data.
            Defaults to the collection given at instantiation.
        **kwargs
            Forwarded to extract_timestamps.

        Returns
        -------
        pandas.DataFrame
            A table of intensity for each region, with associated times, wavelengths, names and colors

        Raises
        ------
        ValueError
            If the raw data has neither a 'LedState' nor a 'Flags' column.
        """
        collection = collection or self.collection
        fp_data = alfio.load_object(self.session_path / collection, 'fpData')
        raw = fp_data['raw']
        ts = self.extract_timestamps(raw, **kwargs)

        # Load channels meta data; the named argument is used (it never ends up in kwargs)
        channel_meta_map = self._channel_meta(light_source_map)
        led_states = fp_data.get('channels', pd.DataFrame(NEUROPHOTOMETRICS_LED_STATES))
        led_states = led_states.set_index('Condition')
        # Extract signal columns into a table
        regions = regions or [k for k in raw.keys() if 'Region' in k]
        out_df = raw.filter(items=regions, axis=1).sort_index(axis=1)
        out_df['times'] = ts
        out_df['wavelength'] = np.nan
        out_df['name'] = ''
        out_df['color'] = ''
        # Extract channel index
        states = raw.get('LedState', raw.get('Flags', None))
        if states is None:
            raise ValueError("Raw photometry data has neither a 'LedState' nor a 'Flags' column")
        for state in states.unique():
            # the column in which the flag is found designates the light source
            _, ic = np.where(led_states == state)
            if ic.size == 0:
                continue  # unknown flag: the meta data columns keep their default values
            source = channel_meta_map.iloc[ic[0]]
            for cn in ['name', 'color', 'wavelength']:
                out_df.loc[states == state, cn] = source[cn]
        return out_df

    def extract_timestamps(self, fp_data, **kwargs):
        """Extract the photometry.timestamps array.

        This depends on the DAQ and task synchronization protocol.
        NB: the DAQ file is looked up in the collection given at instantiation.

        Parameters
        ----------
        fp_data : pandas.DataFrame
            The raw fibrephotometry table, passed as `df_photometry` to sync_photometry_to_daq.

        Returns
        -------
        numpy.ndarray
            An array of timestamps, one per frame.

        Raises
        ------
        FileNotFoundError
            If no tdms file is found in the photometry collection.
        AssertionError
            If the bpod to DAQ drift is 100 ppm or above, or if 5 or more DAQ events are unmatched.
        """
        daq_folder = self.session_path.joinpath(self.collection)
        daq_file = next(daq_folder.glob('*.tdms'), None)
        if daq_file is None:
            raise FileNotFoundError(f'No tdms DAQ file found in {daq_folder}')
        vdaq, fs = read_daq_voltage(daq_file, chmap=DAQ_CHMAP)
        # photometry frame times expressed in DAQ time
        ts, _, _ = sync_photometry_to_daq(
            vdaq=vdaq, fs=fs, df_photometry=fp_data, v_threshold=V_THRESHOLD)
        gc_bpod, _ = GoCueTriggerTimes(session_path=self.session_path).extract(task_collection='raw_behavior_data', save=False)
        gc_daq = rises(vdaq['bpod'])

        fcn_daq2_bpod, drift_ppm, idaq, _ = sync_timestamps(
            gc_daq / fs, gc_bpod, return_indices=True)
        assert drift_ppm < 100, f"Drift between bpod and daq is above 100 ppm: {drift_ppm}"
        assert (gc_daq.size - idaq.size) < 5, "Bpod and daq synchronisation failed as too few " \
                                              "events could be matched"
        # convert from DAQ time to bpod time
        return fcn_daq2_bpod(ts)
22279

23280

24281
# upload to the session endpoint, qc per regions
@@ -295,15 +552,33 @@ def sync_timestamps(daq_data, fp_data, trials):
295552

296553
use_times = ~fp_data['bpod_times'].isna()
297554

298-
fcn = interpolate.interp1d(fp_data['Timestamp'][use_times].values, fp_data['bpod_times'][use_times].values,
299-
fill_value="extrapolate")
555+
fcn = scipy.interpolate.interp1d(
556+
fp_data['Timestamp'][use_times].values, fp_data['bpod_times'][use_times].values, fill_value="extrapolate")
300557

301558
ts = fcn(fp_data['Timestamp'].values)
302559

303560
return ts
304561

305562

306-
class FibrePhotometryPreprocess(PhotometryPreprocess):
563+
class FibrePhotometryPreprocess(base_tasks.DynamicTask):
564+
@property
565+
def signature(self):
566+
signature = {
567+
'input_files': [('_mcc_DAQdata.raw.tdms', self.device_collection, True),
568+
('_neurophotometrics_fpData.raw.pqt', self.device_collection, True)],
569+
'output_files': [('photometry.signal.pqt', 'alf/photometry', True)]
570+
}
571+
return signature
572+
573+
priority = 90
574+
level = 1
575+
576+
def __init__(self, session_path, regions=None, **kwargs):
577+
super().__init__(session_path, **kwargs)
578+
# Task collection (this needs to be specified in the task kwargs)
579+
self.collection = self.get_task_collection(kwargs.get('collection', None))
580+
self.device_collection = self.get_device_collection('photometry', device_collection='raw_photometry_data')
581+
self.regions = regions
307582

308583
def _run(self, **kwargs):
309584
_, out_files = FibrePhotometry(self.session_path, collection=self.device_collection).extract(

0 commit comments

Comments
 (0)