Skip to content

Commit 2336277

Browse files
committed
Update task.py
1 parent 0911d7e commit 2336277

File tree

1 file changed

+40
-32
lines changed
  • iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld

1 file changed

+40
-32
lines changed

iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55

66
import logging
77
from pathlib import Path
8+
from typing import Any
89

10+
import numpy as np
911
import yaml
12+
from pydantic import NonNegativeFloat
1013

11-
from iblrig.base_choice_world import TrainingChoiceWorldSession
1214
from iblrig.misc import get_task_arguments
13-
from pybpodapi.state_machine import StateMachine
15+
from iblrig_tasks._iblrig_tasks_trainingChoiceWorld.task import Session as TrainingCWSession
1416

1517
log = logging.getLogger('iblrig.task')
1618

@@ -20,30 +22,14 @@
2022
DEFAULTS = yaml.safe_load(f)
2123

2224

23-
class AdaptiveTimeoutStateMachine(StateMachine):
24-
25-
def __init__(
26-
self,
27-
bpod,
28-
adaptive_delay_nogo,
29-
adaptive_delay_error
30-
):
31-
super().__init__(bpod)
32-
self.adaptive_delay_nogo = adaptive_delay_nogo
33-
self.adaptive_delay_error = adaptive_delay_error
25+
class AdaptiveTimeoutChoiceWorldTrialData(TrainingCWSession.TrialDataModel):
26+
adaptive_delay_nogo: NonNegativeFloat
27+
adaptive_delay_error: NonNegativeFloat
3428

3529

36-
def add_state(self, **kwargs):
37-
match kwargs['state_name']:
38-
case 'nogo':
39-
pass
40-
case 'error':
41-
pass
42-
super().add_state(**kwargs)
43-
44-
45-
class Session(TrainingChoiceWorldSession):
30+
class Session(TrainingCWSession):
4631
protocol_name = 'nate_adaptiveTimeoutChoiceWorld'
32+
TrialDataModel = AdaptiveTimeoutChoiceWorldTrialData
4733

4834
def __init__(
4935
self,
@@ -52,14 +38,36 @@ def __init__(
5238
adaptive_delay_error=DEFAULTS['ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS'],
5339
**kwargs,
5440
):
55-
self.adaptive_delay_nogo = adaptive_delay_nogo
56-
self.adaptive_delay_error = adaptive_delay_error
41+
self._adaptive_delay_nogo = adaptive_delay_nogo
42+
self._adaptive_delay_error = adaptive_delay_error
5743
super().__init__(*args, **kwargs)
58-
assert len(self.adaptive_delay_nogo) == len(self.task_params.CONTRAST_SET)
59-
assert len(self.adaptive_delay_error) == len(self.task_params.CONTRAST_SET)
60-
61-
def _instantiate_state_machine(self, trial_number=None):
62-
return AdaptiveTimeoutStateMachine(self.bpod, self.adaptive_delay_nogo, self.adaptive_delay_error)
44+
assert len(self._adaptive_delay_nogo) == len(self.task_params.CONTRAST_SET)
45+
assert len(self._adaptive_delay_error) == len(self.task_params.CONTRAST_SET)
46+
47+
def draw_next_trial_info(self, **kwargs):
48+
super().draw_next_trial_info(**kwargs)
49+
contrast = self.trials_table.at[self.trial_num, 'contrast']
50+
index = np.flatnonzero(np.array(self.task_params['CONTRAST_SET']) == contrast)[0]
51+
self.trials_table.at[self.trial_num, 'adaptive_delay_nogo'] = self._adaptive_delay_nogo[index]
52+
self.trials_table.at[self.trial_num, 'adaptive_delay_error'] = self._adaptive_delay_error[index]
53+
54+
@property
55+
def feedback_nogo_delay(self):
56+
return self.trials_table.at[self.trial_num, 'adaptive_delay_nogo']
57+
58+
@property
59+
def feedback_error_delay(self):
60+
return self.trials_table.at[self.trial_num, 'adaptive_delay_error']
61+
62+
def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
63+
trial_info = self.trials_table.iloc[self.trial_num]
64+
info_dict = {
65+
'Adaptive no-go delay': f'{trial_info.adaptive_delay_nogo:.2f} s',
66+
'Adaptive error delay': f'{trial_info.adaptive_delay_error:.2f} s',
67+
}
68+
if isinstance(extra_info, dict):
69+
info_dict.update(extra_info)
70+
super().show_trial_log(extra_info=info_dict, log_level=log_level)
6371

6472
@staticmethod
6573
def extra_parser():
@@ -71,7 +79,7 @@ def extra_parser():
7179
default=DEFAULTS['ADAPTIVE_FEEDBACK_NOGO_DELAY_SECS'],
7280
nargs='+',
7381
type=float,
74-
help='list of delays for no-go condition (contrasts: 1.0, 0.25, 0.125, 0.0625, 0.0)',
82+
help='list of delays for no-go condition (contrasts: 1.0, 0.5, 0.25, 0.125, 0.0625, 0.0)',
7583
)
7684
parser.add_argument(
7785
'--adaptive_delay_error',
@@ -80,7 +88,7 @@ def extra_parser():
8088
default=DEFAULTS['ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS'],
8189
nargs='+',
8290
type=float,
83-
help='list of delays for error condition (contrasts: 1.0, 0.25, 0.125, 0.0625, 0.0)',
91+
help='list of delays for error condition (contrasts: 1.0, 0.5, 0.25, 0.125, 0.0625, 0.0)',
8492
)
8593
return parser
8694

0 commit comments

Comments
 (0)