
Commit 04f0c4f

GDM Neurolab authored and copybara-github committed
internal
PiperOrigin-RevId: 835323237
1 parent a128770 commit 04f0c4f

File tree

1 file changed: +54 -45 lines changed


disentangled_rnns/library/two_armed_bandits.py

Lines changed: 54 additions & 45 deletions
@@ -91,17 +91,18 @@ class EnvironmentBanditsDrift(BaseEnvironment):
     n_arms: The number of arms in the environment.
   """
 
-  def __init__(self,
-               sigma: float,
-               p_instructed: float = 0.0,
-               seed: Optional[int] = None,
-               n_arms: int = 2,
-               ):
+  def __init__(
+      self,
+      sigma: float,
+      p_instructed: float = 0.0,
+      seed: Optional[int] = None,
+      n_arms: int = 2,
+  ):
     super().__init__(seed=seed, n_arms=n_arms)
 
     # Check inputs
     if sigma < 0:
-      msg = ('sigma was {}, but must be greater than 0')
+      msg = 'sigma was {}, but must be greater than 0'
       raise ValueError(msg.format(sigma))
 
     # Initialize persistent properties
@@ -116,8 +117,7 @@ def new_session(self):
     # Sample randomly between 0 and 1
     self._reward_probs = self._random_state.rand(self.n_arms)
 
-  def step(self,
-           attempted_choice: int) -> tuple[int, float, int]:
+  def step(self, attempted_choice: int) -> tuple[int, float, int]:
     """Run a single trial of the task.
 
     Args:
@@ -129,7 +129,6 @@ def step(self,
         that trial.
       reward: The reward to be given to the agent. 0 or 1.
       instructed: 1 if the choice was instructed, 0 otherwise
-
     """
     if attempted_choice == -1:
       choice = -1
@@ -139,8 +138,10 @@ def step(self,
 
     # Check inputs
     if attempted_choice not in list(range(self.n_arms)):
-      msg = (f'choice given was {attempted_choice}, but must be one of '
-             f'{list(range(self.n_arms))}.')
+      msg = (
+          f'choice given was {attempted_choice}, but must be one of '
+          f'{list(range(self.n_arms))}.'
+      )
       raise ValueError(msg)
 
     # If choice was instructed, overrule it and decide randomly
@@ -154,7 +155,8 @@ def step(self,
     reward = self._random_state.rand() < self._reward_probs[choice]
     # Add gaussian noise to reward probabilities
     drift = self._random_state.normal(
-        loc=0, scale=self._sigma, size=self.n_arms)
+        loc=0, scale=self._sigma, size=self.n_arms
+    )
     self._reward_probs += drift
 
     # Fix reward probs that've drifted below 0 or above 1
@@ -186,7 +188,7 @@ def __init__(
     """Initialize the environment.
 
     Args:
-      payout_matrix: A numpy array of shape (n_sessions, n_actions, n_trials)
+      payout_matrix: A numpy array of shape (n_sessions, n_trials, n_actions)
         giving the reward for each session, action, and trial. These are
         deterministic, i.e. for the same trial_num, session_num, and action, the
         reward will always be the same. (If you'd like stochastic rewards you
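Note on the docstring fix above: the documented payout_matrix layout is now (n_sessions, n_trials, n_actions). A quick, hypothetical illustration of that axis order (the array below is made up; only the shape matters):

import numpy as np

# Hypothetical payout matrix: 3 sessions, 100 trials, 2 arms, deterministic 0/1
# rewards indexed as payout_matrix[session, trial, action].
n_sessions, n_trials, n_actions = 3, 100, 2
rng = np.random.default_rng(seed=0)
payout_matrix = rng.integers(0, 2, size=(n_sessions, n_trials, n_actions))

assert payout_matrix.shape == (n_sessions, n_trials, n_actions)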
@@ -206,7 +208,9 @@ def __init__(
     if instructed_matrix is not None:
       self._instructed_matrix = instructed_matrix
     else:
-      self._instructed_matrix = np.nan * np.zeros_like(payout_matrix)
+      self._instructed_matrix = np.full(
+          (self._n_sessions, self._n_trials), np.nan
+      )
 
     self._current_session = 0
     self._current_trial = 0
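This hunk changes more than formatting: the old fallback built an all-NaN array matching payout_matrix (sessions x trials x actions), while the new fallback is explicitly (n_sessions, n_trials). A small sketch of the difference, assuming _n_sessions and _n_trials come from the first two axes of payout_matrix:

import numpy as np

payout_matrix = np.zeros((3, 100, 2))  # (n_sessions, n_trials, n_actions)
n_sessions, n_trials = payout_matrix.shape[:2]

old_default = np.nan * np.zeros_like(payout_matrix)    # shape (3, 100, 2)
new_default = np.full((n_sessions, n_trials), np.nan)  # shape (3, 100)

print(old_default.shape, new_default.shape)  # (3, 100, 2) (3, 100)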
@@ -231,8 +235,10 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]:
 
     # Check inputted choice is valid.
     if attempted_choice not in list(range(self.n_arms)):
-      msg = (f'choice given was {attempted_choice}, but must be one of '
-             f'{list(range(self.n_arms))}.')
+      msg = (
+          f'choice given was {attempted_choice}, but must be one of '
+          f'{list(range(self.n_arms))}.'
+      )
       raise ValueError(msg)
 
     if self._current_trial >= self._n_trials:
@@ -262,7 +268,9 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]:
   def payout(self) -> np.ndarray:
     """Get possible payouts for current session, trial across actions."""
     return self._payout_matrix[
-        self._current_session, self._current_trial, :].copy()
+        self._current_session, self._current_trial, :
+    ].copy()
+
 
 ##########
 # AGENTS #
@@ -274,7 +282,6 @@ class AgentQ:
 
   Attributes:
     q: The agent's current estimate of the reward probability on each arm
-
   """
 
   def __init__(
def __init__(
@@ -298,7 +305,8 @@ def new_session(self):
298305

299306
def get_choice_probs(self) -> np.ndarray:
300307
choice_probs = np.exp(self._beta * self.q) / np.sum(
301-
np.exp(self._beta * self.q))
308+
np.exp(self._beta * self.q)
309+
)
302310
return choice_probs
303311

304312
def get_choice(self) -> int:
@@ -308,9 +316,7 @@ def get_choice(self) -> int:
     choice = np.random.choice(2, p=choice_probs)
     return choice
 
-  def update(self,
-             choice: int,
-             reward: float):
+  def update(self, choice: int, reward: float):
     """Update the agent after one step of the task.
 
     Args:
@@ -350,12 +356,11 @@ def __init__(
 
   def new_session(self):
     """Reset the agent for the beginning of a new session."""
-    self.theta = 0. * np.ones(2)
+    self.theta = 0.0 * np.ones(2)
     self.v = 0.5
 
   def get_choice_probs(self) -> np.ndarray:
-    choice_probs = np.exp(self.theta) / np.sum(
-        np.exp(self.theta))
+    choice_probs = np.exp(self.theta) / np.sum(np.exp(self.theta))
     return choice_probs
 
   def get_choice(self) -> int:
def get_choice(self) -> int:
@@ -379,9 +384,11 @@ def update(self, choice: int, reward: float):
379384
choice_probs = self.get_choice_probs()
380385
rpe = reward - self.v
381386
self.theta[choice] = (1 - self._alpha_actor_forget) * self.theta[
382-
choice] + self._alpha_actor_learn * rpe * (1 - choice_probs[choice])
387+
choice
388+
] + self._alpha_actor_learn * rpe * (1 - choice_probs[choice])
383389
self.theta[unchosen] = (1 - self._alpha_actor_forget) * self.theta[
384-
unchosen] - self._alpha_actor_learn * rpe * (choice_probs[unchosen])
390+
unchosen
391+
] - self._alpha_actor_learn * rpe * (choice_probs[unchosen])
385392

386393
# Critic learing: V moves towards reward
387394
self.v = (1 - self._alpha_critic) * self.v + self._alpha_critic * reward
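The re-wrapped lines above are purely a formatting change; the update itself is unchanged. For readability, a standalone sketch of the same actor-critic update for the two-armed case, with parameter names mirroring the class's private attributes:

import numpy as np

def actor_critic_update(theta, v, choice, reward,
                        alpha_actor_learn, alpha_actor_forget, alpha_critic):
  """Sketch reproducing the update in the diff above (two arms)."""
  unchosen = 1 - choice
  choice_probs = np.exp(theta) / np.sum(np.exp(theta))
  rpe = reward - v  # reward prediction error
  theta = theta.copy()
  theta[choice] = ((1 - alpha_actor_forget) * theta[choice]
                   + alpha_actor_learn * rpe * (1 - choice_probs[choice]))
  theta[unchosen] = ((1 - alpha_actor_forget) * theta[unchosen]
                     - alpha_actor_learn * rpe * choice_probs[unchosen])
  v = (1 - alpha_critic) * v + alpha_critic * reward  # critic moves towards reward
  return theta, v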
@@ -395,9 +402,7 @@ class AgentNetwork:
     params: A set of Haiku parameters suitable for that architecture
   """
 
-  def __init__(self,
-               make_network: Callable[[], hk.RNNCore],
-               params: hk.Params):
+  def __init__(self, make_network: Callable[[], hk.RNNCore], params: hk.Params):
 
     def step_network(
         xs: np.ndarray, state: hk.State
@@ -449,9 +454,9 @@ class SessData(NamedTuple):
   n_trials: int
 
 
-def run_experiment(agent: Agent,
-                   environment: EnvironmentBanditsDrift,
-                   n_steps: int) -> SessData:
+def run_experiment(
+    agent: Agent, environment: EnvironmentBanditsDrift, n_steps: int
+) -> SessData:
   """Runs a behavioral session from a given agent and environment.
 
   Args:
@@ -479,25 +484,29 @@ def run_experiment(agent: Agent,
     choices[step] = choice
     rewards[step] = reward
 
-  experiment = SessData(choices=choices,
-                        rewards=rewards,
-                        n_trials=n_steps,
-                        reward_probs=reward_probs)
+  experiment = SessData(
+      choices=choices,
+      rewards=rewards,
+      n_trials=n_steps,
+      reward_probs=reward_probs,
+  )
   return experiment
 
 
-def create_dataset(agent: Agent,
-                   environment: EnvironmentBanditsDrift,
-                   n_steps_per_session: int,
-                   n_sessions: int,
-                   batch_size: int) -> rnn_utils.DatasetRNN:
+def create_dataset(
+    agent: Agent,
+    environment: EnvironmentBanditsDrift,
+    n_steps_per_session: int,
+    n_sessions: int,
+    batch_size: int,
+) -> rnn_utils.DatasetRNN:
   """Generates a behavioral dataset from a given agent and environment.
 
   Args:
     agent: An agent object to generate choices
     environment: An environment object to generate rewards
-    n_steps_per_session: The number of trials in each behavioral session to
-      be generated
+    n_steps_per_session: The number of trials in each behavioral session to be
+      generated
     n_sessions: The number of sessions to generate
     batch_size: The size of the batches to serve from the dataset
 
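End to end, the reformatted signatures are called exactly as before. A minimal usage sketch: the EnvironmentBanditsDrift, run_experiment, and create_dataset signatures come from this diff, while the AgentQ constructor arguments (alpha, beta) are assumed and not shown here.

from disentangled_rnns.library import two_armed_bandits

environment = two_armed_bandits.EnvironmentBanditsDrift(
    sigma=0.1,
    p_instructed=0.0,
    seed=0,
    n_arms=2,
)
agent = two_armed_bandits.AgentQ(alpha=0.3, beta=3.0)  # assumed constructor args

# One behavioral session of 200 trials.
session = two_armed_bandits.run_experiment(
    agent=agent, environment=environment, n_steps=200
)

# A dataset of simulated sessions for RNN fitting.
dataset = two_armed_bandits.create_dataset(
    agent=agent,
    environment=environment,
    n_steps_per_session=200,
    n_sessions=100,
    batch_size=100,
)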