"""Tonotopic mapping task for IBL-Rig.

This module implements a passive tonotopic mapping protocol that plays a
set of pure tones (and optionally white noise) at various levels via the
Harp sound card, while emitting TTLs on BNC2 for synchronization. The
sequence is organized as a Bpod state machine with up to 255 states per
trial. After the session, the raw jsonable file is converted to a compact
Parquet table containing the stimulus on/off events and stimulus
parameters.

Key concepts:
- Frequencies are generated on a log-spaced grid between freq_0 and
  freq_1, optionally prepended by a white-noise pseudo-frequency (-1).
- For each frequency and level combination, a waveform is pre-rendered
  and uploaded to the Harp sound card.
- Attenuation corrections can be provided via attenuation.csv in the
  task directory; otherwise a zero-attenuation LUT is created.
- Repetitions per trial are split to respect the 255-state limit.

The module also exposes a helper function `create_dataframe` to extract a
Pandas DataFrame from a session's `_iblrig_taskData.raw.jsonable` file,
keeping only the audio TTL channel and parsing the state names to recover
stimulus parameters.
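
A minimal sketch of the grid and repetition logic (illustrative values,
not the task defaults)::

    import numpy as np

    # five log-spaced tone frequencies between 2 and 20 kHz
    frequencies = np.round(np.logspace(np.log10(2000), np.log10(20000), num=5)).astype(int)
    # -> [2000, 3557, 6325, 11247, 20000]
    frequencies = np.insert(frequencies, 0, -1)  # prepend the white-noise sentinel

    # split 25 repetitions of 12 stimuli into Bpod runs of at most 255 states
    max_reps_per_trial = 255 // 12  # 21 repetitions fit into one run
    # -> repetition plan: [21, 4]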
| 24 | +""" |
| 25 | + |
| 26 | +import logging |
| 27 | +import time |
| 28 | +from pathlib import Path |
| 29 | +from typing import cast |
| 30 | + |
| 31 | +import numpy as np |
| 32 | +import pandas as pd |
| 33 | +from pydantic import FilePath, validate_call |
| 34 | + |
| 35 | +from iblrig import sound |
| 36 | +from iblrig.base_choice_world import NTRIALS_INIT |
| 37 | +from iblrig.base_tasks import BaseSession, BpodMixin |
| 38 | +from iblrig.misc import get_task_arguments |
| 39 | +from iblrig.pydantic_definitions import TrialDataModel |
| 40 | +from iblrig.raw_data_loaders import bpod_session_data_to_dataframe |
| 41 | +from iblutil.io import jsonable |
| 42 | +from pybpodapi.state_machine import StateMachine |
| 43 | + |
| 44 | +log = logging.getLogger('iblrig') |
| 45 | + |
| 46 | + |
| 47 | +class TonotopicMappingTrialData(TrialDataModel): |
| 48 | + """Schema for per-trial metadata recorded by this task. |
| 49 | +
|
| 50 | + Attributes |
| 51 | + ---------- |
| 52 | + frequency_sequence : list[int] |
| 53 | + The sequence of frequencies (Hz; -1 encodes white noise) that |
| 54 | + were played within the trial, in the order they occurred. |
| 55 | + level_sequence : list[int] |
| 56 | + The corresponding sequence of level values (attenuation/gain in |
| 57 | + dB) used for each stimulus within the trial. |
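
    For example, a trial stepping through three stimuli might record
    ``frequency_sequence=[-1, 2000, 8000]`` and ``level_sequence=[-30, -30, -20]``
    (illustrative values).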
| 58 | + """ |
| 59 | + |
| 60 | + frequency_sequence: list[int] |
| 61 | + level_sequence: list[int] |
| 62 | + |
| 63 | + |
| 64 | +class Session(BpodMixin, BaseSession): |
| 65 | + """Tonotopic mapping session orchestrated via Bpod and Harp. |
| 66 | +
|
| 67 | + This session pre-generates a grid of stimuli (frequency x level), |
| 68 | + uploads them to the Harp sound-card, and then steps through a Bpod |
| 69 | + state machine that triggers each stimulus with a fixed duration and a |
| 70 | + pause between stimuli. The order can be shuffled per trial. |
| 71 | + """ |
| 72 | + |
| 73 | + protocol_name = 'samuel_tonotopicMapping' |
| 74 | + TrialDataModel = TonotopicMappingTrialData |
| 75 | + |
| 76 | + parameters: np.ndarray = np.array([[], []]) |
| 77 | + sequence: np.ndarray = np.array([]) |
| 78 | + trial_num: int = -1 |
| 79 | + |
| 80 | + def __init__(self, *args, **kwargs): |
| 81 | + """Initialize the session, stimuli, and repetition plan. |
| 82 | +
|
| 83 | + Steps: |
| 84 | + - Validate that a Harp sound card is used and can hold all |
| 85 | + waveforms (<= 29). |
| 86 | + - Build the frequency grid (log-spaced) and combine with levels. |
| 87 | + - Load or create an attenuation LUT. |
| 88 | + - Split desired repetitions into chunks that fit into 255 states. |
| 89 | + - Pre-render and register all stimuli with appropriate gains. |
| 90 | + """ |
| 91 | + super().__init__(*args, **kwargs) |
| 92 | + self.trials_table = self.TrialDataModel.preallocate_dataframe(NTRIALS_INIT) |
| 93 | + |
| 94 | + # Hardware constraints: Harp output only and waveform count limit |
| 95 | + assert self.hardware_settings.device_sound.OUTPUT == 'harp', 'This task requires a Harp sound-card' |
| 96 | + assert self.task_params['n_freqs'] * len(self.task_params['levels']) <= 29, 'Harp only supports up to 29 waveforms' |
| 97 | + |
| 98 | + # Define frequencies (log spaced from freq_0 to freq_1, rounded to nearest integer) |
| 99 | + frequencies = np.logspace( |
| 100 | + np.log10(self.task_params['freq_0']), |
| 101 | + np.log10(self.task_params['freq_1']), |
| 102 | + num=self.task_params['n_freqs'] - self.task_params['include_white_noise'], |
| 103 | + ) |
| 104 | + frequencies = np.round(frequencies).astype(int) |
| 105 | + if self.task_params['include_white_noise']: |
| 106 | + # Use -1 as a sentinel for white noise to keep arrays numeric |
| 107 | + frequencies = np.insert(frequencies, 0, -1, axis=0) |
| 108 | + |
| 109 | + # Get all parameter combinations (frequency x level) |
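        # e.g., frequencies [-1, 2000] x levels [-30, -20] yield rows
        # [-1, -30], [-1, -20], [2000, -30], [2000, -20]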
        Session.parameters = np.array(np.meshgrid(frequencies, self.task_params['levels'])).T.reshape(-1, 2)

        # Get LUT (or create new one based on frequencies) for corrective gains
        attenuation_file = self.get_task_directory().joinpath('attenuation.csv')
        if attenuation_file.exists():
            self.attenuation_lut = pd.read_csv(attenuation_file)
        else:
            self.attenuation_lut = pd.DataFrame({'frequency_hz': frequencies, 'attenuation_db': np.zeros(len(frequencies))})
            self.attenuation_lut.to_csv(attenuation_file, index=False)

        # Calculate repetitions per state machine run (255 states max)
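        # e.g., 24 stimuli allow 255 // 24 = 10 repetitions per run,
        # so 25 requested repetitions split into [10, 10, 5]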
        self.repetitions = []
        max_reps_per_trial = 255 // self.n_stimuli
        reps_remaining = self.task_params['n_reps_per_stim']
        while reps_remaining > 0:
            self.repetitions.append(min(max_reps_per_trial, reps_remaining))
            reps_remaining -= self.repetitions[-1]

        # Select the playback channel by mirroring the configured default,
        # i.e. sound is output on the channel opposite to DEFAULT_CHANNELS.
        match self.hardware_settings.device_sound.DEFAULT_CHANNELS:
            case 'left':
                channels = 'right'
            case 'right':
                channels = 'left'
            case _:
                channels = 'stereo'

        # Generate and register stimuli for Harp
        self.stimuli = []
        for stimulus_index in range(self.n_stimuli):
            frequency = self.parameters[stimulus_index][0]
            level = self.parameters[stimulus_index][1]
            tmp = sound.make_sound(
                rate=self.task_params['fs'],
                frequency=frequency,
                duration=self.task_params['d_sound'],
                amplitude=self.task_params['amplitude'],
                fade=self.task_params['d_ramp'],
                chans=channels,
                # Combine corrective gain (from LUT) with requested level
                gain_db=self.get_corrective_gain(frequency) + level,
            )
            self.stimuli.append(tmp)
        # Harp indices start at 2 because 1 is reserved by Bpod
        self.harp_indices = list(range(2, self.n_stimuli + 2))

    @property
    def n_stimuli(self):
        """Total number of distinct stimuli (frequency x level)."""
        return self.parameters.shape[0]

    @property
    def n_trials(self):
        """Number of Bpod runs required given the 255-state constraint."""
        return len(self.repetitions)

    def get_corrective_gain(self, frequency: int):
        """Return corrective gain in dB for a given frequency."""
        return np.interp(frequency, self.attenuation_lut['frequency_hz'], self.attenuation_lut['attenuation_db'])

    def start_mixin_sound(self):
        """Upload waveforms to Harp and register Bpod output actions."""
        log.info(f'Pushing {len(self.parameters)} stimuli to Harp soundcard')
        sound.configure_sound_card(sounds=self.stimuli, indexes=self.harp_indices, sample_rate=self.task_params['fs'])

        module = self.bpod.sound_card
        module_port = f'Serial{module.serial_port if module is not None else "3"}'
        for stimulus_idx, harp_idx in enumerate(self.harp_indices):
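            # serial message: ord('P') ('play') followed by the Harp sound index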
            bpod_message = [ord('P'), harp_idx]
            bpod_action = (module_port, self.bpod._define_message(self.bpod.sound_card, bpod_message))
            self.bpod.actions.update({f'stim_{stimulus_idx}': bpod_action})

        # Soft code allows logging per-state when the state is entered
        self.bpod.softcode_handler_function = self.softcode_handler

    def start_hardware(self):
        """Start Bpod and sound-card mixins."""
        self.start_mixin_bpod()
        self.start_mixin_sound()

    @staticmethod
    def get_state_name(state_idx: int):
        """Return human-readable state name including parameters.

        The name is formatted as: "{idx:03d}_{freq_label}_{gain}dB" where
        freq_label is e.g. "8000Hz" or "WN" for white noise. An extra
        "exit" state name is returned past the last stimulus index.
| 198 | + """ |
| 199 | + if state_idx < len(Session.sequence): |
| 200 | + stimulus_idx = Session.sequence[state_idx] |
| 201 | + frequency = Session.parameters[stimulus_idx][0] |
| 202 | + gain = Session.parameters[stimulus_idx][1] |
| 203 | + return '{:03d}_{:s}_{:d}dB'.format(state_idx, f'{frequency:d}Hz' if frequency >= 0 else 'WN', gain) |
| 204 | + else: |
| 205 | + return 'exit' |
| 206 | + |
| 207 | + @staticmethod |
| 208 | + def softcode_handler(softcode: int) -> None: |
| 209 | + """Log information about the current state entry. |
| 210 | +
|
| 211 | + Parameters |
| 212 | + ---------- |
| 213 | + softcode : int |
| 214 | + One-based state index sent by the state machine. |
| 215 | + """ |
| 216 | + state_index = softcode - 1 |
| 217 | + stimulus_index = Session.sequence[state_index] |
| 218 | + frequency = Session.parameters[stimulus_index][0] |
| 219 | + gain = Session.parameters[stimulus_index][1] |
| 220 | + n_states = len(Session.sequence) |
| 221 | + if frequency >= 0: |
| 222 | + log.info(f'- {state_index + 1:03d}/{n_states:03d}: {frequency:8d} Hz, {gain:3d} dB') |
| 223 | + else: |
| 224 | + log.info(f'- {state_index + 1:03d}/{n_states:03d}: white noise, {gain:3d} dB') |
| 225 | + |
| 226 | + def get_state_machine(self, trial_number: int) -> StateMachine: |
| 227 | + """Construct the Bpod state machine for a given trial. |
| 228 | +
|
| 229 | + The sequence of stimuli is repeated `repetitions[trial_number]` |
| 230 | + times and optionally shuffled deterministically using the trial |
| 231 | + number as seed. |
| 232 | + """ |
| 233 | + # generate sequence, optionally shuffled (seeded with trial number) |
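        # e.g., 3 stimuli at 2 repetitions -> [0, 0, 1, 1, 2, 2], shuffled in place below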
        Session.sequence = np.repeat(np.arange(self.n_stimuli), self.repetitions[trial_number])
        if self.task_params['shuffle']:
            np.random.seed(trial_number)
            np.random.shuffle(Session.sequence)

        # build state machine
        sma = StateMachine(self.bpod)
        for state_idx, stimulus_idx in enumerate(self.sequence):
            sma.add_state(
                state_name=self.get_state_name(state_idx),
                state_timer=self.task_params['d_sound'] + self.task_params['d_pause'],
                output_actions=[self.bpod.actions[f'stim_{stimulus_idx}'], ('SoftCode', state_idx + 1)],
                state_change_conditions={'Tup': self.get_state_name(state_idx + 1)},
            )
        return sma

    def _run(self):
        """Run the session across all required Bpod state-machine runs."""
        for trial_number in range(self.n_trials):
            self.trial_num = trial_number

            # run state machine
            log.info(f'Starting Trial #{trial_number} ({trial_number + 1}/{self.n_trials})')
            sma = self.get_state_machine(trial_number)
            self.bpod.send_state_machine(sma)
            self.bpod.run_state_machine(sma)

            # handle pause event
            if self.paused and trial_number < (self.n_trials - 1):
                log.info(f'Pausing session between trials #{trial_number} and #{trial_number + 1}')
                while self.paused and not self.stopped:
                    time.sleep(1)
                if not self.stopped:
                    log.info('Resuming session')

            # save trial data: also store the exact sequence for QC and extraction
            self.trials_table.at[self.trial_num, 'frequency_sequence'] = self.parameters[self.sequence, 0]
            self.trials_table.at[self.trial_num, 'level_sequence'] = self.parameters[self.sequence, 1]
            bpod_data = self.bpod.session.current_trial.export()
            self.save_trial_data_to_json(bpod_data)

            # handle stop event
            if self.stopped:
                log.info('Stopping session after trial #%d', trial_number)
                break

        # convert data to parquet and remove jsonable file
        path_jsonable = cast(Path, self.paths['DATA_FILE_PATH'])
        path_parquet = path_jsonable.with_suffix('.pqt')
        data = create_dataframe(path_jsonable)
        data.to_parquet(path_parquet)
        assert path_parquet.exists()
        path_jsonable.unlink()


@validate_call
def create_dataframe(jsonable_file: FilePath) -> pd.DataFrame:
    """Create a compact DataFrame of audio TTL events from a jsonable file.

    This utility loads the raw task jsonable, keeps only the audio TTL
    channel (BNC2), and parses the Bpod state names to recover the
    stimulus index, frequency (or white noise), and attenuation. The
    output is suitable for downstream analysis and alignment.

    Parameters
    ----------
    jsonable_file : str | os.PathLike
        Path to a session's `_iblrig_taskData.raw.jsonable` file.

    Returns
    -------
    pd.DataFrame
        Columns: Trial, Stimulus, Value, Frequency, Attenuation.

    Raises
    ------
    ValueError
        If the input file is not named `_iblrig_taskData.raw.jsonable` or
        if it doesn't contain audio TTLs on channel BNC2.
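
    Examples
    --------
    A minimal sketch, assuming a local session folder (the path is illustrative):

    >>> from pathlib import Path
    >>> raw_file = Path('raw_task_data_00') / '_iblrig_taskData.raw.jsonable'
    >>> df = create_dataframe(raw_file)  # doctest: +SKIP
    >>> df.columns.tolist()  # doctest: +SKIP
    ['Trial', 'Stimulus', 'Value', 'Frequency', 'Attenuation']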
| 313 | + """ |
| 314 | + |
| 315 | + # check argument |
| 316 | + if jsonable_file.name != '_iblrig_taskData.raw.jsonable': |
| 317 | + raise ValueError('Input file must be named `_iblrig_taskData.raw.jsonable`') |
| 318 | + |
| 319 | + # load data |
| 320 | + bpod_dicts = jsonable.load_task_jsonable(jsonable_file)[1] |
| 321 | + bpod_data = bpod_session_data_to_dataframe(bpod_dicts) |
| 322 | + |
| 323 | + # restrict to audio TTL events |
| 324 | + output = bpod_data[bpod_data['Channel'].eq('BNC2')].copy() |
| 325 | + if len(output) == 0: |
| 326 | + raise ValueError('No audio TTLs found in the provided file') |
| 327 | + |
| 328 | + # extract stimulus parameters from state names |
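    # e.g., '003_8000Hz_-20dB' -> Stimulus=3, Frequency=8000, Attenuation=-20
    # and '000_WN_-30dB' -> Frequency='WN', mapped to -1 below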
    output[['Stimulus', 'Frequency', 'Attenuation']] = output['State'].str.extract(r'^(\d+)_(\d+|WN)[^-\d]+([-\d]+)dB$')
    output.replace({'Frequency': 'WN'}, '-1', inplace=True)
    output[['Stimulus', 'Frequency', 'Attenuation']] = output[['Stimulus', 'Frequency', 'Attenuation']].astype('Int64')

    # remove / reorder columns
    return output[['Trial', 'Stimulus', 'Value', 'Frequency', 'Attenuation']]


if __name__ == '__main__':
    kwargs = get_task_arguments()
    sess = Session(**kwargs)
    sess.run()