Skip to content

Commit 1ef3f1b

Browse files
authored
Merge pull request #43 from int-brain-lab/sp_tonotopic
SP Tonotopic Mapping
2 parents c9c613e + 78000e4 commit 1ef3f1b

File tree

11 files changed

+584
-7
lines changed

11 files changed

+584
-7
lines changed

.editorconfig

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# EditorConfig is awesome: https://editorconfig.org
2+
3+
# top-most EditorConfig file
4+
root = true
5+
6+
# General settings
7+
[*]
8+
end_of_line = lf
9+
insert_final_newline = true
10+
trim_trailing_whitespace = true
11+
max_line_length = 80
12+
charset = utf-8
13+
14+
# 4 space indentation
15+
[*.py]
16+
max_line_length = 130
17+
indent_style = space
18+
indent_size = 4
19+
20+
# Indentation override for rst files
21+
[*.rst]
22+
indent_style = space
23+
indent_size = 3
24+
25+
# Indentation override for rst and yaml files
26+
[*.{toml,yaml,qrc}]
27+
indent_style = space
28+
indent_size = 2

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,6 @@ dmypy.json
129129

130130
# Pyre type checker
131131
.pyre/
132+
133+
.idea
134+
uv.lock
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
attenuation.csv
Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
"""Tonotopic mapping task for IBL-Rig.
2+
3+
This module implements a passive tonotopic mapping protocol that plays a
4+
set of pure tones (and optionally white noise) at various levels via the
5+
Harp sound card, while emitting TTLs on BNC2 for synchronization. The
6+
sequence is organized as a Bpod state machine with up to 255 states per
7+
trial. After the session, the raw jsonable file is converted to a compact
8+
Parquet table containing the stimulus on/off events and stimulus
9+
parameters.
10+
11+
Key concepts:
12+
- Frequencies are generated on a log-spaced grid between freq_0 and
13+
freq_1, optionally prepended by a white-noise pseudo-frequency (-1).
14+
- For each frequency and level combination, a waveform is pre-rendered
15+
and uploaded to the Harp soundcard.
16+
- Attenuation corrections can be provided via attenuation.csv in the
17+
task directory; otherwise a zero-attenuation LUT is created.
18+
- Repetitions per trial are split to respect the 255-state limit.
19+
20+
The module also exposes a helper function `create_dataframe` to extract a
21+
Pandas DataFrame from a session's `_iblrig_taskData.raw.jsonable` file,
22+
keeping only the audio TTL channel and parsing the state names to recover
23+
stimulus parameters.
24+
"""
25+
26+
import logging
27+
import time
28+
from pathlib import Path
29+
from typing import cast
30+
31+
import numpy as np
32+
import pandas as pd
33+
from pydantic import FilePath, validate_call
34+
35+
from iblrig import sound
36+
from iblrig.base_choice_world import NTRIALS_INIT
37+
from iblrig.base_tasks import BaseSession, BpodMixin
38+
from iblrig.misc import get_task_arguments
39+
from iblrig.pydantic_definitions import TrialDataModel
40+
from iblrig.raw_data_loaders import bpod_session_data_to_dataframe
41+
from iblutil.io import jsonable
42+
from pybpodapi.state_machine import StateMachine
43+
44+
log = logging.getLogger('iblrig')
45+
46+
47+
class TonotopicMappingTrialData(TrialDataModel):
48+
"""Schema for per-trial metadata recorded by this task.
49+
50+
Attributes
51+
----------
52+
frequency_sequence : list[int]
53+
The sequence of frequencies (Hz; -1 encodes white noise) that
54+
were played within the trial, in the order they occurred.
55+
level_sequence : list[int]
56+
The corresponding sequence of level values (attenuation/gain in
57+
dB) used for each stimulus within the trial.
58+
"""
59+
60+
frequency_sequence: list[int]
61+
level_sequence: list[int]
62+
63+
64+
class Session(BpodMixin, BaseSession):
65+
"""Tonotopic mapping session orchestrated via Bpod and Harp.
66+
67+
This session pre-generates a grid of stimuli (frequency x level),
68+
uploads them to the Harp sound-card, and then steps through a Bpod
69+
state machine that triggers each stimulus with a fixed duration and a
70+
pause between stimuli. The order can be shuffled per trial.
71+
"""
72+
73+
protocol_name = 'samuel_tonotopicMapping'
74+
TrialDataModel = TonotopicMappingTrialData
75+
76+
parameters: np.ndarray = np.array([[], []])
77+
sequence: np.ndarray = np.array([])
78+
trial_num: int = -1
79+
80+
def __init__(self, *args, **kwargs):
81+
"""Initialize the session, stimuli, and repetition plan.
82+
83+
Steps:
84+
- Validate that a Harp sound card is used and can hold all
85+
waveforms (<= 29).
86+
- Build the frequency grid (log-spaced) and combine with levels.
87+
- Load or create an attenuation LUT.
88+
- Split desired repetitions into chunks that fit into 255 states.
89+
- Pre-render and register all stimuli with appropriate gains.
90+
"""
91+
super().__init__(*args, **kwargs)
92+
self.trials_table = self.TrialDataModel.preallocate_dataframe(NTRIALS_INIT)
93+
94+
# Hardware constraints: Harp output only and waveform count limit
95+
assert self.hardware_settings.device_sound.OUTPUT == 'harp', 'This task requires a Harp sound-card'
96+
assert self.task_params['n_freqs'] * len(self.task_params['levels']) <= 29, 'Harp only supports up to 29 waveforms'
97+
98+
# Define frequencies (log spaced from freq_0 to freq_1, rounded to nearest integer)
99+
frequencies = np.logspace(
100+
np.log10(self.task_params['freq_0']),
101+
np.log10(self.task_params['freq_1']),
102+
num=self.task_params['n_freqs'] - self.task_params['include_white_noise'],
103+
)
104+
frequencies = np.round(frequencies).astype(int)
105+
if self.task_params['include_white_noise']:
106+
# Use -1 as a sentinel for white noise to keep arrays numeric
107+
frequencies = np.insert(frequencies, 0, -1, axis=0)
108+
109+
# Get all parameter combinations (frequency x level)
110+
Session.parameters = np.array(np.meshgrid(frequencies, self.task_params['levels'])).T.reshape(-1, 2)
111+
112+
# Get LUT (or create new one based on frequencies) for corrective gains
113+
attenuation_file = self.get_task_directory().joinpath('attenuation.csv')
114+
if attenuation_file.exists():
115+
self.attenuation_lut = pd.read_csv(self.get_task_directory().joinpath('attenuation.csv'))
116+
else:
117+
self.attenuation_lut = pd.DataFrame({'frequency_hz': frequencies, 'attenuation_db': np.zeros(len(frequencies))})
118+
self.attenuation_lut.to_csv(attenuation_file, index=False)
119+
120+
# Calculate repetitions per state machine run (255 states max)
121+
self.repetitions = []
122+
max_reps_per_trial = 255 // self.n_stimuli
123+
reps_remaining = self.task_params['n_reps_per_stim']
124+
while reps_remaining > 0:
125+
self.repetitions.append(min(max_reps_per_trial, reps_remaining))
126+
reps_remaining -= self.repetitions[-1]
127+
128+
# Select channel configuration for playback. We mirror the default
129+
# so that sound is output on the opposite channel of DEFAULT.
130+
match self.hardware_settings.device_sound.DEFAULT_CHANNELS:
131+
case 'left':
132+
channels = 'right'
133+
case 'right':
134+
channels = 'left'
135+
case _:
136+
channels = 'stereo'
137+
138+
# Generate and register stimuli for Harp
139+
self.stimuli = []
140+
for stimulus_index in range(self.n_stimuli):
141+
frequency = self.parameters[stimulus_index][0]
142+
level = self.parameters[stimulus_index][1]
143+
tmp = sound.make_sound(
144+
rate=self.task_params['fs'],
145+
frequency=frequency,
146+
duration=self.task_params['d_sound'],
147+
amplitude=self.task_params['amplitude'],
148+
fade=self.task_params['d_ramp'],
149+
chans=channels,
150+
# Combine corrective gain (from LUT) with requested level
151+
gain_db=self.get_corrective_gain(frequency) + level,
152+
)
153+
self.stimuli.append(tmp)
154+
# Harp indexes start at 2 because 1 is reserved by Bpod
155+
self.harp_indices = [i for i in range(2, self.n_stimuli + 2)]
156+
157+
@property
158+
def n_stimuli(self):
159+
"""Total number of distinct stimuli (frequency x level)."""
160+
return self.parameters.shape[0]
161+
162+
@property
163+
def n_trials(self):
164+
"""Number of Bpod runs required given the 255-state constraint."""
165+
return len(self.repetitions)
166+
167+
def get_corrective_gain(self, frequency: int):
168+
"""Return corrective gain in dB for a given frequency."""
169+
return np.interp(frequency, self.attenuation_lut['frequency_hz'], self.attenuation_lut['attenuation_db'])
170+
171+
def start_mixin_sound(self):
172+
"""Upload waveforms to Harp and register Bpod output actions."""
173+
log.info(f'Pushing {len(self.parameters)} stimuli to Harp soundcard')
174+
sound.configure_sound_card(sounds=self.stimuli, indexes=self.harp_indices, sample_rate=self.task_params['fs'])
175+
176+
module = self.bpod.sound_card
177+
module_port = f'Serial{module.serial_port if module is not None else "3"}'
178+
for stimulus_idx, harp_idx in enumerate(self.harp_indices):
179+
bpod_message = [ord('P'), harp_idx]
180+
bpod_action = (module_port, self.bpod._define_message(self.bpod.sound_card, bpod_message))
181+
self.bpod.actions.update({f'stim_{stimulus_idx}': bpod_action})
182+
183+
# Soft code allows logging per-state when the state is entered
184+
self.bpod.softcode_handler_function = self.softcode_handler
185+
186+
def start_hardware(self):
187+
"""Start Bpod and sound-card mixins."""
188+
self.start_mixin_bpod()
189+
self.start_mixin_sound()
190+
191+
@staticmethod
192+
def get_state_name(state_idx: int):
193+
"""Return human-readable state name including parameters.
194+
195+
The name is formatted as: "{idx:03d}_{freq_label}_{gain}dB" where
196+
freq_label is e.g. "8000Hz" or "WN" for white noise. An extra
197+
"exit" state name is returned past the last stimulus index.
198+
"""
199+
if state_idx < len(Session.sequence):
200+
stimulus_idx = Session.sequence[state_idx]
201+
frequency = Session.parameters[stimulus_idx][0]
202+
gain = Session.parameters[stimulus_idx][1]
203+
return '{:03d}_{:s}_{:d}dB'.format(state_idx, f'{frequency:d}Hz' if frequency >= 0 else 'WN', gain)
204+
else:
205+
return 'exit'
206+
207+
@staticmethod
208+
def softcode_handler(softcode: int) -> None:
209+
"""Log information about the current state entry.
210+
211+
Parameters
212+
----------
213+
softcode : int
214+
One-based state index sent by the state machine.
215+
"""
216+
state_index = softcode - 1
217+
stimulus_index = Session.sequence[state_index]
218+
frequency = Session.parameters[stimulus_index][0]
219+
gain = Session.parameters[stimulus_index][1]
220+
n_states = len(Session.sequence)
221+
if frequency >= 0:
222+
log.info(f'- {state_index + 1:03d}/{n_states:03d}: {frequency:8d} Hz, {gain:3d} dB')
223+
else:
224+
log.info(f'- {state_index + 1:03d}/{n_states:03d}: white noise, {gain:3d} dB')
225+
226+
def get_state_machine(self, trial_number: int) -> StateMachine:
227+
"""Construct the Bpod state machine for a given trial.
228+
229+
The sequence of stimuli is repeated `repetitions[trial_number]`
230+
times and optionally shuffled deterministically using the trial
231+
number as seed.
232+
"""
233+
# generate sequence, optionally shuffled (seeded with trial number)
234+
Session.sequence = np.repeat(np.arange(self.n_stimuli), self.repetitions[trial_number])
235+
if self.task_params['shuffle']:
236+
np.random.seed(trial_number)
237+
np.random.shuffle(Session.sequence)
238+
239+
# build state machine
240+
sma = StateMachine(self.bpod)
241+
for state_idx, stimulus_idx in enumerate(self.sequence):
242+
sma.add_state(
243+
state_name=self.get_state_name(state_idx),
244+
state_timer=self.task_params['d_sound'] + self.task_params['d_pause'],
245+
output_actions=[self.bpod.actions[f'stim_{stimulus_idx}'], ('SoftCode', state_idx + 1)],
246+
state_change_conditions={'Tup': self.get_state_name(state_idx + 1)},
247+
)
248+
return sma
249+
250+
def _run(self):
251+
"""Run the session across all required Bpod state-machine runs."""
252+
for trial_number in range(self.n_trials):
253+
self.trial_num = trial_number
254+
255+
# run state machine
256+
log.info(f'Starting Trial #{trial_number} ({trial_number + 1}/{self.n_trials})')
257+
sma = self.get_state_machine(trial_number)
258+
self.bpod.send_state_machine(sma)
259+
self.bpod.run_state_machine(sma)
260+
261+
# handle pause event
262+
if self.paused and trial_number < (self.task_params.NTRIALS - 1):
263+
log.info(f'Pausing session inbetween trials #{trial_number} and #{trial_number + 1}')
264+
while self.paused and not self.stopped:
265+
time.sleep(1)
266+
if not self.stopped:
267+
log.info('Resuming session')
268+
269+
# save trial data: also store the exact sequence for QC and extraction
270+
self.trials_table.at[self.trial_num, 'frequency_sequence'] = self.parameters[self.sequence, 0]
271+
self.trials_table.at[self.trial_num, 'level_sequence'] = self.parameters[self.sequence, 1]
272+
bpod_data = self.bpod.session.current_trial.export()
273+
self.save_trial_data_to_json(bpod_data)
274+
275+
# handle stop event
276+
if self.stopped:
277+
log.info('Stopping session after trial #%d', trial_number)
278+
break
279+
280+
# convert data to parquet and remove jsonable file
281+
path_jsonable = cast(Path, self.paths['DATA_FILE_PATH'])
282+
path_parquet = path_jsonable.with_suffix('.pqt')
283+
data = create_dataframe(path_jsonable)
284+
data.to_parquet(path_parquet)
285+
assert path_parquet.exists()
286+
path_jsonable.unlink()
287+
288+
289+
@validate_call
290+
def create_dataframe(jsonable_file: FilePath) -> pd.DataFrame:
291+
"""Create a compact DataFrame of audio TTL events from a jsonable file.
292+
293+
This utility loads the raw task jsonable, keeps only the audio TTL
294+
channel (BNC2), and parses the Bpod state names to recover the
295+
stimulus index, frequency (or white noise), and attenuation. The
296+
output is suitable for downstream analysis and alignment.
297+
298+
Parameters
299+
----------
300+
jsonable_file : str | os.PathLike
301+
Path to a session's `_iblrig_taskData.raw.jsonable` file.
302+
303+
Returns
304+
-------
305+
pd.DataFrame
306+
Columns: Trial, Stimulus, Value, Frequency, Attenuation.
307+
308+
Raises
309+
------
310+
ValueError
311+
If the input file is not named `_iblrig_taskData.raw.jsonable` or
312+
if it doesn't contain audio TTLs on channel BNC2.
313+
"""
314+
315+
# check argument
316+
if jsonable_file.name != '_iblrig_taskData.raw.jsonable':
317+
raise ValueError('Input file must be named `_iblrig_taskData.raw.jsonable`')
318+
319+
# load data
320+
bpod_dicts = jsonable.load_task_jsonable(jsonable_file)[1]
321+
bpod_data = bpod_session_data_to_dataframe(bpod_dicts)
322+
323+
# restrict to audio TTL events
324+
output = bpod_data[bpod_data['Channel'].eq('BNC2')].copy()
325+
if len(output) == 0:
326+
raise ValueError('No audio TTLs found in the provided file')
327+
328+
# extract stimulus parameters from state names
329+
output[['Stimulus', 'Frequency', 'Attenuation']] = output['State'].str.extract(r'^(\d+)_(\d+|WN)[^-\d]+([-\d]+)dB$')
330+
output.replace({'Frequency': 'WN'}, '-1', inplace=True)
331+
output[['Stimulus', 'Frequency', 'Attenuation']] = output[['Stimulus', 'Frequency', 'Attenuation']].astype('Int64')
332+
333+
# remove / reorder columns
334+
return output[['Trial', 'Stimulus', 'Value', 'Frequency', 'Attenuation']]
335+
336+
337+
if __name__ == '__main__':
338+
kwargs = get_task_arguments()
339+
sess = Session(**kwargs)
340+
sess.run()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
'freq_0': 2000 # lowest frequency to map (Hz)
2+
'freq_1': 16000 # highest frequency to map (Hz)
3+
'include_white_noise': true # whether to include a white noise stimulus
4+
'n_freqs': 10 # number of frequencies to map (including white noise)
5+
'levels': [0, -6, -12] # stimulus gain levels to test (dB)
6+
'n_reps_per_stim': 10 # total number of presentations for each stimulus (across trials)
7+
'd_pause': 0.01 # duration of silence inbetween sound pulses (s)
8+
'd_sound': 0.15 # duration of sound pulse (s)
9+
'd_ramp': 0.01 # duration of Hanning ramp (s)
10+
'amplitude': 0.05 # initial amplitude of all stimuli prior to dB attenuation
11+
'fs': 192000 # sampling rate (Hz) - Harp supports 96000 Hz and 192000 Hz
12+
'shuffle': true # whether to shuffle the order of frequencies or not
13+
'skip_attenuation': false # whether to skip gain correction (potentially useful for calibration)

projects/extraction_tasks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
from projects.nate_optoBiasedChoiceWorld import OptoTrialsNidq, OptoTrialsBpod
1010
from projects._sp_passiveVideo import PassiveVideoTimeline
1111
from projects.max_optoStaticTrainingChoiceWorld import PulsePalTrialsBpod
12+
from projects.samuel_tonotopicMapping import TonotopicMappingTimeline

0 commit comments

Comments
 (0)