Skip to content

Commit 9c59b89

Browse files
mdmelink1o0
andauthored
Add custom opto task and extraction for U19 project (#15)
* add staticTrainingChoiceWorld * fix folder name * fix duplicated stimulus * Add pulsepal mixin * add task * skeleton extraction code for opto task * fix import and abstract property bugs * laser time is a parameter * todos * small tweaks * Task running. Need to change on state and fix stim rampdown * add GlobalTimer to control the rampdown of opto * hacky fix for first trial * track opto on time * update default params: add 0.5 contrast * extract opto intervals * docs * extractor and qc * qc logic * extractor map typo * map non-opto task * update non-opto task mapping * basic LED calibration * make punish timeout a parameter * proper path * Update pyproject.toml --------- Co-authored-by: k1o0 <[email protected]>
1 parent 1a498dd commit 9c59b89

File tree

11 files changed

+659
-1
lines changed

11 files changed

+659
-1
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import logging
2+
import sys
3+
from typing import Literal
4+
from abc import ABC, abstractmethod
5+
import numpy as np
6+
7+
from iblrig.base_choice_world import SOFTCODE
8+
from pybpodapi.protocol import StateMachine, Bpod
9+
from pypulsepal import PulsePalObject
10+
from iblrig.base_tasks import BaseSession
11+
12+
log = logging.getLogger('iblrig.task')
13+
14+
SOFTCODE_FIRE_PULSEPAL = max(SOFTCODE).value + 1
15+
SOFTCODE_STOP_PULSEPAL = max(SOFTCODE).value + 2
16+
V_MAX = 5
17+
18+
19+
class PulsePalStateMachine(StateMachine):
20+
"""
21+
This class adds:
22+
1. Hardware or sofware triggering of optogenetic stimulation via a PulsePal (or BPod Analog Output Module)
23+
EITHER
24+
- adds soft-codes for starting and stopping the opto stim
25+
OR
26+
- sets up a TTL to hardware trigger the PulsePal
27+
2. (not yet implemented!!!) sets up a TTL channel for recording opto stim times from the PulsePal
28+
"""
29+
# TODO: define the TTL channel for recording opto stim times?
30+
def __init__(
31+
self,
32+
bpod,
33+
trigger_type: Literal['soft', 'hardware'] = 'soft',
34+
is_opto_stimulation=False,
35+
states_opto_ttls=None,
36+
states_opto_stop=None,
37+
opto_t_max_seconds=None,
38+
):
39+
super().__init__(bpod)
40+
self.trigger_type = trigger_type
41+
self.is_opto_stimulation = is_opto_stimulation
42+
self.states_opto_ttls = states_opto_ttls or []
43+
self.states_opto_stop = states_opto_stop or []
44+
45+
# Set global timer 1 for T_MAX
46+
self.set_global_timer(timer_id=1, timer_duration=opto_t_max_seconds)
47+
48+
def add_state(self, **kwargs):
49+
if self.is_opto_stimulation:
50+
if kwargs['state_name'] in self.states_opto_ttls:
51+
if self.trigger_type == 'soft':
52+
kwargs['output_actions'] += [('SoftCode', SOFTCODE_FIRE_PULSEPAL),]
53+
elif self.trigger_type == 'hardware':
54+
kwargs['output_actions'] += [('BNC2', 255),]
55+
kwargs['output_actions'] += [(Bpod.OutputChannels.GlobalTimerTrig, 1)] # start the global timer when the opto stim comes on
56+
elif kwargs['state_name'] in self.states_opto_stop:
57+
if self.trigger_type == 'soft':
58+
kwargs['output_actions'] += [('SoftCode', SOFTCODE_STOP_PULSEPAL),]
59+
elif self.trigger_type == 'hardware':
60+
kwargs['output_actions'] += [('BNC2', 0),]
61+
62+
super().add_state(**kwargs)
63+
64+
class PulsePalMixin(ABC):
65+
"""
66+
A mixin class that adds optogenetic stimulation capabilities to a task via the
67+
PulsePal module (or a Analog Output module running PulsePal firmware). It is used
68+
in conjunction with the PulsePalStateMachine class rather than the StateMachine class.
69+
70+
The user must define the arm_opto_stim method to define the parameters for optogenetic stimulation.
71+
PulsePalMixin supports soft-code triggering via the start_opto_stim and stop_opto_stim methods.
72+
Hardware triggering is also supported by defining trigger channels in the arm_opto_stim method.
73+
74+
The opto stim is currently hard-coded on output channel 1.
75+
A TTL pulse is hard-coded on output channel 2 for accurately recording trigger times. This TTL
76+
will rise when the opto stim starts and fall when it stops, thus accurately recording software trigger times.
77+
"""
78+
79+
def start_opto_hardware(self):
80+
self.pulsepal_connection = PulsePalObject('COM13') # TODO: get port from hardware params
81+
log.warning('Connected to PulsePal')
82+
# TODO: get the calibration value for this specific cannula
83+
#super().start_hardware() # TODO: move this out
84+
85+
# add the softcodes for the PulsePal
86+
soft_code_dict = self.bpod.softcodes
87+
soft_code_dict.update({SOFTCODE_STOP_PULSEPAL: self.stop_opto_stim})
88+
soft_code_dict.update({SOFTCODE_FIRE_PULSEPAL: self.start_opto_stim})
89+
self.bpod.register_softcodes(soft_code_dict)
90+
91+
@abstractmethod
92+
def arm_opto_stim(self, ttl_output_channel):
93+
raise NotImplementedError("User must define the stimulus and trigger type to deliver with pulsepal")
94+
# Define the pulse sequence and load it to the desired output channel here
95+
# This method should not fire the pulse train, that is handled by start_opto_stim() (soft-trigger) or a hardware trigger
96+
# See https://github.com/sanworks/PulsePal/blob/master/Python/Python3/PulsePalExample.py for examples
97+
# you should also define the max_stim_seconds property here to set the maximum duration of the pulse train
98+
99+
##############################
100+
# Example code to define a sine wave lasting 5 seconds
101+
voltages = list(range(0, 1000))
102+
for i in voltages:
103+
voltages[i] = math.sin(voltages[i]/float(10))*10 # Set 1,000 voltages to create a 20V peak-to-peak sine waveform
104+
times = np.linspace(0, 5, len(voltages)) # Create a time vector for the waveform
105+
self.stim_length_seconds = times[-1] # it is essential to get this property right so that the TTL for recording stim pulses is correcty defined
106+
self.pulsepal_connection.sendCustomPulseTrain(1, times, voltages)
107+
self.pulsepal_connection.programOutputChannelParam('customTrainID', 1, 1)
108+
##############################
109+
110+
@property
111+
@abstractmethod
112+
def stim_length_seconds():
113+
# this should be set within the arm_opto_stim method
114+
pass
115+
116+
def arm_ttl_stim(self):
117+
# a TTL pulse from channel 2 that rises when the opto stim starts and falls when it stops
118+
log.warning('Arming TTL signal')
119+
self.pulsepal_connection.programOutputChannelParam('phase1Duration', 2, self.stim_length_seconds)
120+
self.pulsepal_connection.sendCustomPulseTrain(2, [0,], [V_MAX,])
121+
self.pulsepal_connection.programOutputChannelParam('customTrainID', 2, 2)
122+
123+
def start_opto_stim(self):
124+
self.pulsepal_connection.triggerOutputChannels(1, 1, 0, 0)
125+
log.warning('Started opto stim')
126+
127+
def stop_opto_stim(self):
128+
# this will stop the pulse train instantly (and the corresponding TTL pulse)
129+
# To avoid rebound spiking in the case of GtACR, a ramp down is recommended
130+
self.pulsepal_connection.abortPulseTrains()
131+
132+
def compute_vmax_from_calibration(self, calibration_value):
133+
# TODO: implement this method to convert the calibration value to a voltage for the opto stim
134+
pass
135+
136+
def __del__(self):
137+
del self.pulsepal_connection

iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/__init__.py

Whitespace-only changes.
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""
2+
This task is a replica of max_staticTrainingChoiceWorld with the addition of optogenetic stimulation
3+
An `opto_stimulation` column is added to the trials_table, which is a boolean array of length NTRIALS_INIT
4+
The PROBABILITY_OPTO_STIMULATION parameter is used to determine the probability of optogenetic stimulation
5+
for each trial
6+
7+
Additionally the state machine is modified to add output TTLs for optogenetic stimulation
8+
"""
9+
10+
import logging
11+
import random
12+
import sys
13+
from importlib.util import find_spec
14+
from pathlib import Path
15+
from typing import Literal
16+
import pandas as pd
17+
18+
import numpy as np
19+
import yaml
20+
import time
21+
22+
import iblrig
23+
from iblrig.base_choice_world import SOFTCODE
24+
from pybpodapi.protocol import StateMachine
25+
from iblrig_custom_tasks.max_staticTrainingChoiceWorld.task import Session as StaticTrainingChoiceSession
26+
from iblrig_custom_tasks.max_optoStaticTrainingChoiceWorld.PulsePal import PulsePalMixin, PulsePalStateMachine
27+
28+
stim_location_history = []
29+
30+
log = logging.getLogger('iblrig.task')
31+
32+
NTRIALS_INIT = 2000
33+
SOFTCODE_FIRE_LED = max(SOFTCODE).value + 1
34+
SOFTCODE_RAMP_DOWN_LED = max(SOFTCODE).value + 2
35+
RAMP_SECONDS = .25 # time to ramp down the opto stim # TODO: make this a parameter
36+
LED_V_MAX = 5 # maximum voltage for LED control # TODO: make this a parameter
37+
38+
# read defaults from task_parameters.yaml
39+
with open(Path(__file__).parent.joinpath('task_parameters.yaml')) as f:
40+
DEFAULTS = yaml.safe_load(f)
41+
42+
class Session(StaticTrainingChoiceSession, PulsePalMixin):
43+
protocol_name = 'max_optoStaticTrainingChoiceWorld'
44+
extractor_tasks = ['PulsePalTrials']
45+
46+
def __init__(
47+
self,
48+
*args,
49+
probability_opto_stim: float = DEFAULTS['PROBABILITY_OPTO_STIM'],
50+
opto_ttl_states: list[str] = DEFAULTS['OPTO_TTL_STATES'],
51+
opto_stop_states: list[str] = DEFAULTS['OPTO_STOP_STATES'],
52+
max_laser_time: float = DEFAULTS['MAX_LASER_TIME'],
53+
estimated_led_power_mW: float = DEFAULTS['ESTIMATED_LED_POWER_MW'],
54+
**kwargs,
55+
):
56+
super().__init__(*args, **kwargs)
57+
self.task_params['OPTO_TTL_STATES'] = opto_ttl_states
58+
self.task_params['OPTO_STOP_STATES'] = opto_stop_states
59+
self.task_params['PROBABILITY_OPTO_STIM'] = probability_opto_stim
60+
self.task_params['MAX_LASER_TIME'] = max_laser_time
61+
self.task_params['LED_POWER'] = estimated_led_power_mW
62+
# generates the opto stimulation for each trial
63+
opto = np.random.choice(
64+
[0, 1],
65+
p=[1 - probability_opto_stim, probability_opto_stim],
66+
size=NTRIALS_INIT,
67+
).astype(bool)
68+
69+
opto[0] = False
70+
self.trials_table['opto_stimulation'] = opto
71+
72+
# get the calibration values for the LED
73+
# TODO: do a calibration curve instead
74+
dat = pd.read_csv(r'Y:/opto_fiber_calibration_values.csv')
75+
l_cannula = f'{kwargs["subject"]}L' #TODO: where is SUBJECT defined?
76+
r_cannula = f'{kwargs["subject"]}R'
77+
l_cable = 0
78+
r_cable = 1
79+
l_cal_power = dat[(dat['Cannula'] == l_cannula) & (dat['cable_ID'] == l_cable)].cable_power.values[0]
80+
r_cal_power = dat[(dat['Cannula'] == r_cannula) & (dat['cable_ID'] == r_cable)].cable_power.values[0]
81+
82+
mean_cal_power = np.mean([l_cal_power, r_cal_power])
83+
vmax = LED_V_MAX * self.task_params['LED_POWER'] / mean_cal_power
84+
log.warning(f'Using VMAX: {vmax}V for target LED power {self.task_params["LED_POWER"]}mW')
85+
self.task_params['VMAX_LED'] = vmax
86+
87+
def _instantiate_state_machine(self, trial_number=None):
88+
"""
89+
We override this using the custom class PulsePalStateMachine that appends TTLs for optogenetic stimulation where needed
90+
:param trial_number:
91+
:return:
92+
"""
93+
# PWM1 is the LED OUTPUT for port interface board
94+
# Input is PortIn1
95+
# TODO: enable input port?
96+
log.warning('Instantiating state machine')
97+
is_opto_stimulation = self.trials_table.at[trial_number, 'opto_stimulation']
98+
if is_opto_stimulation:
99+
self.arm_opto_stim()
100+
self.arm_ttl_stim()
101+
return PulsePalStateMachine(
102+
self.bpod,
103+
trigger_type='soft', # software trigger
104+
is_opto_stimulation=is_opto_stimulation,
105+
states_opto_ttls=self.task_params['OPTO_TTL_STATES'],
106+
states_opto_stop=self.task_params['OPTO_STOP_STATES'],
107+
opto_t_max_seconds=self.task_params['MAX_LASER_TIME'],
108+
)
109+
110+
def arm_opto_stim(self):
111+
# define a contant offset voltage with a ramp down at the end to avoid rebound excitation
112+
log.warning('Arming opto stim')
113+
ramp = np.linspace(self.task_params['VMAX_LED'], 0, 1000) # SET POWER
114+
t = np.linspace(0, RAMP_SECONDS, 1000)
115+
v = np.concatenate((np.array([self.task_params['VMAX_LED']]), ramp)) # SET POWER
116+
t = np.concatenate((np.array([0]), t + self.task_params['MAX_LASER_TIME']))
117+
118+
self.pulsepal_connection.programOutputChannelParam('phase1Duration', 1, self.task_params['MAX_LASER_TIME'])
119+
self.pulsepal_connection.sendCustomPulseTrain(1, t, v)
120+
self.pulsepal_connection.programOutputChannelParam('customTrainID', 1, 1)
121+
122+
def start_opto_stim(self):
123+
super().start_opto_stim()
124+
self.opto_start_time = time.time()
125+
126+
@property
127+
def stim_length_seconds(self):
128+
return self.task_params['MAX_LASER_TIME']
129+
130+
def stop_opto_stim(self):
131+
if time.time() - self.opto_start_time >= self.task_params['MAX_LASER_TIME']:
132+
# the LED should have turned off by now, we don't need to force the ramp down
133+
log.warning('Stopped opto stim - hit opto timeout')
134+
return
135+
136+
# we will modify this function to ramp down the opto stim rather than abruptly stopping it
137+
# send instructions to set the TTL back to 0
138+
self.pulsepal_connection.programOutputChannelParam('phase1Duration', 2, self.task_params['MAX_LASER_TIME'])
139+
self.pulsepal_connection.sendCustomPulseTrain(2, [0,], [0,])
140+
self.pulsepal_connection.programOutputChannelParam('customTrainID', 2, 2)
141+
142+
# send instructions to ramp the opto stim down to 0
143+
v = np.linspace(self.task_params['VMAX_LED'], 0, 1000)
144+
t = np.linspace(0, RAMP_SECONDS, 1000)
145+
self.pulsepal_connection.programOutputChannelParam('phase1Duration', 1, self.task_params['MAX_LASER_TIME'])
146+
self.pulsepal_connection.sendCustomPulseTrain(1, t, v)
147+
self.pulsepal_connection.programOutputChannelParam('customTrainID', 1, 1)
148+
149+
# trigger these instructions
150+
self.pulsepal_connection.triggerOutputChannels(1, 1, 0, 0)
151+
log.warning('Stopped opto stim - hit a stop opto state')
152+
153+
def start_hardware(self):
154+
super().start_hardware()
155+
super().start_opto_hardware()
156+
157+
158+
@staticmethod
159+
def extra_parser():
160+
""":return: argparse.parser()"""
161+
parser = super(Session, Session).extra_parser()
162+
parser.add_argument(
163+
'--probability_opto_stim',
164+
option_strings=['--probability_opto_stim'],
165+
dest='probability_opto_stim',
166+
default=DEFAULTS['PROBABILITY_OPTO_STIM'],
167+
type=float,
168+
help=f'probability of opto-genetic stimulation (default: {DEFAULTS["PROBABILITY_OPTO_STIM"]})',
169+
)
170+
171+
parser.add_argument(
172+
'--opto_ttl_states',
173+
option_strings=['--opto_ttl_states'],
174+
dest='opto_ttl_states',
175+
default=DEFAULTS['OPTO_TTL_STATES'],
176+
nargs='+',
177+
type=str,
178+
help='list of the state machine states where opto stim should be delivered',
179+
)
180+
parser.add_argument(
181+
'--opto_stop_states',
182+
option_strings=['--opto_stop_states'],
183+
dest='opto_stop_states',
184+
default=DEFAULTS['OPTO_STOP_STATES'],
185+
nargs='+',
186+
type=str,
187+
help='list of the state machine states where opto stim should be stopped',
188+
)
189+
parser.add_argument(
190+
'--max_laser_time',
191+
option_strings=['--max_laser_time'],
192+
dest='max_laser_time',
193+
default=DEFAULTS['MAX_LASER_TIME'],
194+
type=float,
195+
help='Maximum laser duration in seconds',
196+
)
197+
parser.add_argument(
198+
'--estimated_led_power_mW',
199+
option_strings=['--estimated_led_power_mW'],
200+
dest='estimated_led_power_mW',
201+
default=DEFAULTS['ESTIMATED_LED_POWER_MW'],
202+
type=float,
203+
help='The estimated LED power in mW. Computed from a calibration curve'
204+
)
205+
206+
return parser
207+
208+
209+
if __name__ == '__main__': # pragma: no cover
210+
kwargs = iblrig.misc.get_task_arguments(parents=[Session.extra_parser()])
211+
sess = Session(**kwargs)
212+
sess.run()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
'CONTRAST_SET': [1.0, 0.25, 0.125, 0.0625, 0.0, 0.0, 0.0625, 0.125, 0.25, 1.0] # signed contrast set
2+
'PROBABILITY_SET': [2, 2, 2, 2, 1, 1, 2, 2, 2, 2] # scalar or list of n signed contrasts values, if scalar all contingencies are equiprobable
3+
'REWARD_SET_UL': [1.5] # scalar or list of Ncontrast values
4+
'POSITION_SET': [-35, -35, -35, -35, -35, 35, 35, 35, 35, 35] # position set
5+
'STIM_GAIN': 4.0 # wheel to stimulus relationship
6+
'STIM_REVERSE': False
7+
#'DEBIAS': True # Whether to use debiasing rule or not by repeating error trials # todo
8+
9+
# Opto parameters
10+
'OPTO_TTL_STATES': # list of the state machine states where opto stim should be delivered
11+
- trial_start
12+
'OPTO_STOP_STATES':
13+
- no_go
14+
- error
15+
- reward
16+
'PROBABILITY_OPTO_STIM': 0.2 # probability of optogenetic stimulation
17+
'MAX_LASER_TIME': 6.0
18+
'ESTIMATED_LED_POWER_MW': 2.5
19+
#'MASK_TTL_STATES': # list of the state machine states where mask stim should be delivered
20+
# - trial_start
21+
# - delay_initiation
22+
# - reset_rotary_encoder
23+
# - quiescent_period
24+
# - stim_on
25+
# - interactive_delay
26+
# - play_tone
27+
# - reset2_rotary_encoder
28+
# - closed_loop
29+
# - no_go
30+
# - freeze_error
31+
# - error
32+
# - freeze_reward
33+
# - reward

iblrig_custom_tasks/max_staticTrainingChoiceWorld/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)