@@ -91,17 +91,18 @@ class EnvironmentBanditsDrift(BaseEnvironment):
     n_arms: The number of arms in the environment.
   """

-  def __init__(self,
-               sigma: float,
-               p_instructed: float = 0.0,
-               seed: Optional[int] = None,
-               n_arms: int = 2,
-               ):
+  def __init__(
+      self,
+      sigma: float,
+      p_instructed: float = 0.0,
+      seed: Optional[int] = None,
+      n_arms: int = 2,
+  ):
     super().__init__(seed=seed, n_arms=n_arms)

     # Check inputs
     if sigma < 0:
-      msg = ('sigma was {}, but must be greater than 0')
+      msg = 'sigma was {}, but must be greater than 0'
       raise ValueError(msg.format(sigma))

     # Initialize persistent properties
@@ -116,8 +117,7 @@ def new_session(self):
     # Sample randomly between 0 and 1
     self._reward_probs = self._random_state.rand(self.n_arms)

-  def step(self,
-           attempted_choice: int) -> tuple[int, float, int]:
+  def step(self, attempted_choice: int) -> tuple[int, float, int]:
     """Run a single trial of the task.

     Args:
@@ -129,7 +129,6 @@ def step(self,
         that trial.
       reward: The reward to be given to the agent. 0 or 1.
       instructed: 1 if the choice was instructed, 0 otherwise
-
     """
     if attempted_choice == -1:
       choice = -1
@@ -139,8 +138,10 @@ def step(self,

     # Check inputs
     if attempted_choice not in list(range(self.n_arms)):
-      msg = (f'choice given was {attempted_choice}, but must be one of '
-             f'{list(range(self.n_arms))}.')
+      msg = (
+          f'choice given was {attempted_choice}, but must be one of '
+          f'{list(range(self.n_arms))}.'
+      )
       raise ValueError(msg)

     # If choice was instructed, overrule it and decide randomly
@@ -154,7 +155,8 @@ def step(self,
     reward = self._random_state.rand() < self._reward_probs[choice]
     # Add gaussian noise to reward probabilities
     drift = self._random_state.normal(
-        loc=0, scale=self._sigma, size=self.n_arms)
+        loc=0, scale=self._sigma, size=self.n_arms
+    )
     self._reward_probs += drift

     # Fix reward probs that've drifted below 0 or above 1
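
For context on the re-wrapped `normal` call, a standalone sketch of the drift step. The `sigma` value is illustrative, and the final `np.clip` is an assumed way of handling the out-of-range case the comment above refers to, not necessarily what the class itself does:

```python
import numpy as np

# Illustrative values; the class draws these from its own RandomState.
random_state = np.random.RandomState(0)
sigma, n_arms = 0.1, 2
reward_probs = random_state.rand(n_arms)

# Gaussian drift of the per-arm reward probabilities, as in the call above.
drift = random_state.normal(loc=0, scale=sigma, size=n_arms)
reward_probs += drift

# Assumed handling: keep probabilities inside [0, 1] after drifting.
reward_probs = np.clip(reward_probs, 0.0, 1.0)
```
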
@@ -186,7 +188,7 @@ def __init__(
186188 """Initialize the environment.
187189
188190 Args:
189- payout_matrix: A numpy array of shape (n_sessions, n_actions, n_trials )
191+ payout_matrix: A numpy array of shape (n_sessions, n_trials, n_actions )
190192 giving the reward for each session, action, and trial. These are
191193 deterministic, i.e. for the same trial_num, session_num, and action, the
192194 reward will always be the same. (If you'd like stochastic rewards you
@@ -206,7 +208,9 @@ def __init__(
     if instructed_matrix is not None:
       self._instructed_matrix = instructed_matrix
     else:
-      self._instructed_matrix = np.nan * np.zeros_like(payout_matrix)
+      self._instructed_matrix = np.full(
+          (self._n_sessions, self._n_trials), np.nan
+      )

     self._current_session = 0
     self._current_trial = 0
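
The corrected docstring and the new `np.full` default imply the following array shapes; a minimal sketch with illustrative sizes (none of the values below come from this change):

```python
import numpy as np

# Illustrative sizes only.
n_sessions, n_trials, n_actions = 10, 200, 2

# Deterministic 0/1 payouts: one value per (session, trial, action).
payout_matrix = (np.random.rand(n_sessions, n_trials, n_actions) < 0.5).astype(float)

# Default instructed matrix: all-NaN, i.e. no trial is instructed.
instructed_matrix = np.full((n_sessions, n_trials), np.nan)

assert payout_matrix.shape == (n_sessions, n_trials, n_actions)
assert instructed_matrix.shape == (n_sessions, n_trials)
```
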
@@ -231,8 +235,10 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]:

     # Check inputted choice is valid.
     if attempted_choice not in list(range(self.n_arms)):
-      msg = (f'choice given was {attempted_choice}, but must be one of '
-             f'{list(range(self.n_arms))}.')
+      msg = (
+          f'choice given was {attempted_choice}, but must be one of '
+          f'{list(range(self.n_arms))}.'
+      )
       raise ValueError(msg)

     if self._current_trial >= self._n_trials:
@@ -262,7 +268,9 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]:
   def payout(self) -> np.ndarray:
     """Get possible payouts for current session, trial across actions."""
     return self._payout_matrix[
-        self._current_session, self._current_trial, :].copy()
+        self._current_session, self._current_trial, :
+    ].copy()
+

 ##########
 # AGENTS #
@@ -274,7 +282,6 @@ class AgentQ:

   Attributes:
     q: The agent's current estimate of the reward probability on each arm
-
   """

   def __init__(
@@ -298,7 +305,8 @@ def new_session(self):

   def get_choice_probs(self) -> np.ndarray:
     choice_probs = np.exp(self._beta * self.q) / np.sum(
-        np.exp(self._beta * self.q))
+        np.exp(self._beta * self.q)
+    )
     return choice_probs

   def get_choice(self) -> int:
@@ -308,9 +316,7 @@ def get_choice(self) -> int:
     choice = np.random.choice(2, p=choice_probs)
     return choice

-  def update(self,
-             choice: int,
-             reward: float):
+  def update(self, choice: int, reward: float):
     """Update the agent after one step of the task.

     Args:
@@ -350,12 +356,11 @@ def __init__(

   def new_session(self):
     """Reset the agent for the beginning of a new session."""
-    self.theta = 0. * np.ones(2)
+    self.theta = 0.0 * np.ones(2)
     self.v = 0.5

   def get_choice_probs(self) -> np.ndarray:
-    choice_probs = np.exp(self.theta) / np.sum(
-        np.exp(self.theta))
+    choice_probs = np.exp(self.theta) / np.sum(np.exp(self.theta))
     return choice_probs

   def get_choice(self) -> int:
@@ -379,9 +384,11 @@ def update(self, choice: int, reward: float):
     choice_probs = self.get_choice_probs()
     rpe = reward - self.v
     self.theta[choice] = (1 - self._alpha_actor_forget) * self.theta[
-        choice] + self._alpha_actor_learn * rpe * (1 - choice_probs[choice])
+        choice
+    ] + self._alpha_actor_learn * rpe * (1 - choice_probs[choice])
     self.theta[unchosen] = (1 - self._alpha_actor_forget) * self.theta[
-        unchosen] - self._alpha_actor_learn * rpe * (choice_probs[unchosen])
+        unchosen
+    ] - self._alpha_actor_learn * rpe * (choice_probs[unchosen])

     # Critic learning: V moves towards reward
     self.v = (1 - self._alpha_critic) * self.v + self._alpha_critic * reward
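
Since the actor update is now split across indexing lines, here is a standalone sketch of what it computes, with made-up parameter values: the chosen arm's preference moves with the reward-prediction error scaled by `1 - p(choice)`, the unchosen arm moves the opposite way, both decay with the forgetting rate, and the critic value tracks the reward.

```python
import numpy as np

# Made-up parameters for illustration.
alpha_actor_learn, alpha_actor_forget, alpha_critic = 0.3, 0.01, 0.1
theta, v = np.zeros(2), 0.5
choice, reward = 0, 1.0
unchosen = 1 - choice

choice_probs = np.exp(theta) / np.sum(np.exp(theta))  # softmax policy
rpe = reward - v  # reward-prediction error

# Actor update, mirroring the re-wrapped lines above.
theta[choice] = (1 - alpha_actor_forget) * theta[
    choice
] + alpha_actor_learn * rpe * (1 - choice_probs[choice])
theta[unchosen] = (1 - alpha_actor_forget) * theta[
    unchosen
] - alpha_actor_learn * rpe * choice_probs[unchosen]

# Critic update: V moves towards the received reward.
v = (1 - alpha_critic) * v + alpha_critic * reward
```
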
@@ -395,9 +402,7 @@ class AgentNetwork:
     params: A set of Haiku parameters suitable for that architecture
   """

-  def __init__(self,
-               make_network: Callable[[], hk.RNNCore],
-               params: hk.Params):
+  def __init__(self, make_network: Callable[[], hk.RNNCore], params: hk.Params):

     def step_network(
         xs: np.ndarray, state: hk.State
@@ -449,9 +454,9 @@ class SessData(NamedTuple):
   n_trials: int


-def run_experiment(agent: Agent,
-                   environment: EnvironmentBanditsDrift,
-                   n_steps: int) -> SessData:
+def run_experiment(
+    agent: Agent, environment: EnvironmentBanditsDrift, n_steps: int
+) -> SessData:
   """Runs a behavioral session from a given agent and environment.

   Args:
@@ -479,25 +484,29 @@ def run_experiment(agent: Agent,
     choices[step] = choice
     rewards[step] = reward

-  experiment = SessData(choices=choices,
-                        rewards=rewards,
-                        n_trials=n_steps,
-                        reward_probs=reward_probs)
+  experiment = SessData(
+      choices=choices,
+      rewards=rewards,
+      n_trials=n_steps,
+      reward_probs=reward_probs,
+  )
   return experiment


-def create_dataset(agent: Agent,
-                   environment: EnvironmentBanditsDrift,
-                   n_steps_per_session: int,
-                   n_sessions: int,
-                   batch_size: int) -> rnn_utils.DatasetRNN:
+def create_dataset(
+    agent: Agent,
+    environment: EnvironmentBanditsDrift,
+    n_steps_per_session: int,
+    n_sessions: int,
+    batch_size: int,
+) -> rnn_utils.DatasetRNN:
   """Generates a behavioral dataset from a given agent and environment.

   Args:
     agent: An agent object to generate choices
     environment: An environment object to generate rewards
-    n_steps_per_session: The number of trials in each behavioral session to
-      be generated
+    n_steps_per_session: The number of trials in each behavioral session to be
+      generated
     n_sessions: The number of sessions to generate
     batch_size: The size of the batches to serve from the dataset

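
For context, a hedged usage sketch of how the reformatted helpers fit together. The `AgentQ` constructor arguments and all numeric values are assumptions for illustration; only the `run_experiment` and `create_dataset` signatures come from this file.

```python
# Assumed setup; sigma/n_arms are real parameters, AgentQ args are guesses.
environment = EnvironmentBanditsDrift(sigma=0.1, n_arms=2)
agent = AgentQ(alpha=0.3, beta=3.0)  # hypothetical constructor arguments

# One simulated behavioral session.
session = run_experiment(agent, environment, n_steps=200)

# A dataset of simulated sessions for fitting a network.
dataset = create_dataset(
    agent,
    environment,
    n_steps_per_session=200,
    n_sessions=100,
    batch_size=10,
)
```
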