55
66import logging
77from pathlib import Path
8+ from typing import Any
89
10+ import numpy as np
911import yaml
12+ from pydantic import NonNegativeFloat
1013
11- from iblrig .base_choice_world import TrainingChoiceWorldSession
1214from iblrig .misc import get_task_arguments
13- from pybpodapi . state_machine import StateMachine
15+ from iblrig_tasks . _iblrig_tasks_trainingChoiceWorld . task import Session as TrainingCWSession
1416
1517log = logging .getLogger ('iblrig.task' )
1618
2022 DEFAULTS = yaml .safe_load (f )
2123
2224
23- class AdaptiveTimeoutStateMachine (StateMachine ):
24-
25- def __init__ (
26- self ,
27- bpod ,
28- adaptive_delay_nogo ,
29- adaptive_delay_error
30- ):
31- super ().__init__ (bpod )
32- self .adaptive_delay_nogo = adaptive_delay_nogo
33- self .adaptive_delay_error = adaptive_delay_error
25+ class AdaptiveTimeoutChoiceWorldTrialData (TrainingCWSession .TrialDataModel ):
26+ adaptive_delay_nogo : NonNegativeFloat
27+ adaptive_delay_error : NonNegativeFloat
3428
3529
36- def add_state (self , ** kwargs ):
37- match kwargs ['state_name' ]:
38- case 'nogo' :
39- pass
40- case 'error' :
41- pass
42- super ().add_state (** kwargs )
43-
44-
45- class Session (TrainingChoiceWorldSession ):
30+ class Session (TrainingCWSession ):
4631 protocol_name = 'nate_adaptiveTimeoutChoiceWorld'
32+ TrialDataModel = AdaptiveTimeoutChoiceWorldTrialData
4733
4834 def __init__ (
4935 self ,
@@ -52,14 +38,36 @@ def __init__(
5238 adaptive_delay_error = DEFAULTS ['ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS' ],
5339 ** kwargs ,
5440 ):
55- self .adaptive_delay_nogo = adaptive_delay_nogo
56- self .adaptive_delay_error = adaptive_delay_error
41+ self ._adaptive_delay_nogo = adaptive_delay_nogo
42+ self ._adaptive_delay_error = adaptive_delay_error
5743 super ().__init__ (* args , ** kwargs )
58- assert len (self .adaptive_delay_nogo ) == len (self .task_params .CONTRAST_SET )
59- assert len (self .adaptive_delay_error ) == len (self .task_params .CONTRAST_SET )
60-
61- def _instantiate_state_machine (self , trial_number = None ):
62- return AdaptiveTimeoutStateMachine (self .bpod , self .adaptive_delay_nogo , self .adaptive_delay_error )
44+ assert len (self ._adaptive_delay_nogo ) == len (self .task_params .CONTRAST_SET )
45+ assert len (self ._adaptive_delay_error ) == len (self .task_params .CONTRAST_SET )
46+
47+ def draw_next_trial_info (self , ** kwargs ):
48+ super ().draw_next_trial_info (** kwargs )
49+ contrast = self .trials_table .at [self .trial_num , 'contrast' ]
50+ index = np .flatnonzero (np .array (self .task_params ['CONTRAST_SET' ]) == contrast )[0 ]
51+ self .trials_table .at [self .trial_num , 'adaptive_delay_nogo' ] = self ._adaptive_delay_nogo [index ]
52+ self .trials_table .at [self .trial_num , 'adaptive_delay_error' ] = self ._adaptive_delay_error [index ]
53+
54+ @property
55+ def feedback_nogo_delay (self ):
56+ return self .trials_table .at [self .trial_num , 'adaptive_delay_nogo' ]
57+
58+ @property
59+ def feedback_error_delay (self ):
60+ return self .trials_table .at [self .trial_num , 'adaptive_delay_error' ]
61+
62+ def show_trial_log (self , extra_info : dict [str , Any ] | None = None , log_level : int = logging .INFO ):
63+ trial_info = self .trials_table .iloc [self .trial_num ]
64+ info_dict = {
65+ 'Adaptive no-go delay' : f'{ trial_info .adaptive_delay_nogo :.2f} s' ,
66+ 'Adaptive error delay' : f'{ trial_info .adaptive_delay_error :.2f} s' ,
67+ }
68+ if isinstance (extra_info , dict ):
69+ info_dict .update (extra_info )
70+ super ().show_trial_log (extra_info = info_dict , log_level = log_level )
6371
6472 @staticmethod
6573 def extra_parser ():
@@ -71,7 +79,7 @@ def extra_parser():
7179 default = DEFAULTS ['ADAPTIVE_FEEDBACK_NOGO_DELAY_SECS' ],
7280 nargs = '+' ,
7381 type = float ,
74- help = 'list of delays for no-go condition (contrasts: 1.0, 0.25, 0.125, 0.0625, 0.0)' ,
82+ help = 'list of delays for no-go condition (contrasts: 1.0, 0.5, 0. 25, 0.125, 0.0625, 0.0)' ,
7583 )
7684 parser .add_argument (
7785 '--adaptive_delay_error' ,
@@ -80,7 +88,7 @@ def extra_parser():
8088 default = DEFAULTS ['ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS' ],
8189 nargs = '+' ,
8290 type = float ,
83- help = 'list of delays for error condition (contrasts: 1.0, 0.25, 0.125, 0.0625, 0.0)' ,
91+ help = 'list of delays for error condition (contrasts: 1.0, 0.5, 0. 25, 0.125, 0.0625, 0.0)' ,
8492 )
8593 return parser
8694
0 commit comments