79 changes: 42 additions & 37 deletions disentangled_rnns/library/two_armed_bandits.py
@@ -58,7 +58,7 @@ def new_session(self):
"""

@abstractmethod
def step(self, attempted_choice: int) -> tuple[int, float, int]:
def step(self, attempted_choice: int) -> tuple[int, float | int, int]:
"""Executes a single step in the environment.

Args:
@@ -91,17 +91,18 @@ class EnvironmentBanditsDrift(BaseEnvironment):
n_arms: The number of arms in the environment.
"""

def __init__(self,
sigma: float,
p_instructed: float = 0.0,
seed: Optional[int] = None,
n_arms: int = 2,
):
def __init__(
self,
sigma: float,
p_instructed: float = 0.0,
seed: Optional[int] = None,
n_arms: int = 2,
):
super().__init__(seed=seed, n_arms=n_arms)

# Check inputs
if sigma < 0:
msg = ('sigma was {}, but must be greater than 0')
msg = 'sigma was {}, but must be greater than or equal to 0'
raise ValueError(msg.format(sigma))

# Initialize persistent properties
@@ -116,8 +117,7 @@ def new_session(self):
# Sample randomly between 0 and 1
self._reward_probs = self._random_state.rand(self.n_arms)

def step(self,
attempted_choice: int) -> tuple[int, float, int]:
def step(self, attempted_choice: int) -> tuple[int, float, int]:
"""Run a single trial of the task.

Args:
@@ -129,7 +129,6 @@ def step(self,
that trial.
reward: The reward to be given to the agent. 0 or 1.
instructed: 1 if the choice was instructed, 0 otherwise

"""
if attempted_choice == -1:
choice = -1
@@ -139,8 +138,10 @@ def step(self,

# Check inputs
if attempted_choice not in list(range(self.n_arms)):
msg = (f'choice given was {attempted_choice}, but must be one of '
f'{list(range(self.n_arms))}.')
msg = (
f'choice given was {attempted_choice}, but must be one of '
f'{list(range(self.n_arms))}.'
)
raise ValueError(msg)

# If choice was instructed, overrule it and decide randomly
@@ -154,7 +155,8 @@ def step(self,
reward = self._random_state.rand() < self._reward_probs[choice]
# Add gaussian noise to reward probabilities
drift = self._random_state.normal(
loc=0, scale=self._sigma, size=self.n_arms)
loc=0, scale=self._sigma, size=self.n_arms
)
self._reward_probs += drift

# Fix reward probs that've drifted below 0 or above 1
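For orientation, here is a minimal usage sketch of the drifting bandit (illustration only, not part of the diff; the import path is inferred from the file location, and the constructor is assumed to start a first session via the base class):

# Sketch: drive the drifting environment for a few trials.
from disentangled_rnns.library.two_armed_bandits import EnvironmentBanditsDrift

env = EnvironmentBanditsDrift(sigma=0.1, seed=0)  # two arms by default
for _ in range(5):
    choice, reward, instructed = env.step(attempted_choice=0)
    print(choice, reward, instructed)  # reward probs drift by N(0, sigma) each trial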
@@ -186,7 +188,7 @@ def __init__(
"""Initialize the environment.

Args:
payout_matrix: A numpy array of shape (n_sessions, n_actions, n_trials)
payout_matrix: A numpy array of shape (n_sessions, n_trials, n_actions)
giving the reward for each session, trial, and action. These are
deterministic, i.e. for the same trial_num, session_num, and action, the
reward will always be the same. (If you'd like stochastic rewards you
@@ -206,7 +208,9 @@ def __init__(
if instructed_matrix is not None:
self._instructed_matrix = instructed_matrix
else:
self._instructed_matrix = np.nan * np.zeros_like(payout_matrix)
self._instructed_matrix = np.full(
(self._n_sessions, self._n_trials), np.nan
)

self._current_session = 0
self._current_trial = 0
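To make the corrected (n_sessions, n_trials, n_actions) layout concrete, a hedged construction example (payout values invented; instructed_matrix is left at its None default, and any other constructor arguments are assumed to have defaults):

# Sketch only: 2 sessions x 3 trials x 2 arms of deterministic payouts.
import numpy as np

payout_matrix = np.array([
    [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],  # session 0
    [[0.0, 1.0], [1.0, 1.0], [0.0, 0.0]],  # session 1
])
env = EnvironmentPayoutMatrix(payout_matrix)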
@@ -221,7 +225,7 @@ def new_session(self):
)
self._current_trial = 0

def step(self, attempted_choice: int) -> tuple[int, float, int]:
def step(self, attempted_choice: int) -> tuple[int, float | int, int]:
# If agent choice is default empty value -1, return -1 for all outputs.
if attempted_choice == -1:
choice = -1
@@ -231,8 +235,10 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]:

# Check inputted choice is valid.
if attempted_choice not in list(range(self.n_arms)):
msg = (f'choice given was {attempted_choice}, but must be one of '
f'{list(range(self.n_arms))}.')
msg = (
f'choice given was {attempted_choice}, but must be one of '
f'{list(range(self.n_arms))}.'
)
raise ValueError(msg)

if self._current_trial >= self._n_trials:
@@ -256,13 +262,15 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]:
self._current_session, self._current_trial, choice
]
self._current_trial += 1
return choice, float(reward), int(instructed)
return choice, reward.item(), int(instructed)

@property
def payout(self) -> np.ndarray:
"""Get possible payouts for current session, trial across actions."""
return self._payout_matrix[
self._current_session, self._current_trial, :].copy()
self._current_session, self._current_trial, :
].copy()


##########
# AGENTS #
@@ -274,7 +282,6 @@ class AgentQ:

Attributes:
q: The agent's current estimate of the reward probability on each arm

"""

def __init__(
@@ -298,7 +305,8 @@ def new_session(self):

def get_choice_probs(self) -> np.ndarray:
choice_probs = np.exp(self._beta * self.q) / np.sum(
np.exp(self._beta * self.q))
np.exp(self._beta * self.q)
)
return choice_probs

def get_choice(self) -> int:
@@ -308,9 +316,7 @@ def get_choice(self) -> int:
choice = np.random.choice(2, p=choice_probs)
return choice
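A quick numeric check of the softmax rule above (worked example, not part of the diff): with q = [1.0, 0.0] and beta = 2, the first arm gets exp(2) / (exp(2) + exp(0)) ≈ 0.88 of the probability mass.

# Worked softmax example (illustration only).
import numpy as np

q, beta = np.array([1.0, 0.0]), 2.0
probs = np.exp(beta * q) / np.sum(np.exp(beta * q))
print(probs)  # ≈ [0.881, 0.119]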

def update(self,
choice: int,
reward: float):
def update(self, choice: int, reward: float):
"""Update the agent after one step of the task.

Args:
@@ -350,12 +356,11 @@ def __init__(

def new_session(self):
"""Reset the agent for the beginning of a new session."""
self.theta = 0. * np.ones(2)
self.theta = 0.0 * np.ones(2)
self.v = 0.5

def get_choice_probs(self) -> np.ndarray:
choice_probs = np.exp(self.theta) / np.sum(
np.exp(self.theta))
choice_probs = np.exp(self.theta) / np.sum(np.exp(self.theta))
return choice_probs

def get_choice(self) -> int:
@@ -379,9 +384,11 @@ def update(self, choice: int, reward: float):
choice_probs = self.get_choice_probs()
rpe = reward - self.v
self.theta[choice] = (1 - self._alpha_actor_forget) * self.theta[
choice] + self._alpha_actor_learn * rpe * (1 - choice_probs[choice])
choice
] + self._alpha_actor_learn * rpe * (1 - choice_probs[choice])
self.theta[unchosen] = (1 - self._alpha_actor_forget) * self.theta[
unchosen] - self._alpha_actor_learn * rpe * (choice_probs[unchosen])
unchosen
] - self._alpha_actor_learn * rpe * (choice_probs[unchosen])

# Critic learning: V moves towards reward
self.v = (1 - self._alpha_critic) * self.v + self._alpha_critic * reward
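In words, the update above: the reward prediction error rpe = reward − v drives a policy-gradient-style step, scaled by (1 − p) for the chosen arm and by −p for the unchosen arm, while the (1 − alpha_forget) factor leaks both preferences toward zero; the critic then nudges v toward the observed reward. A toy one-step check (illustration only, with made-up learning rates):

# Toy numeric check of the leaky actor-critic update shown above.
import numpy as np

alpha_learn, alpha_forget, alpha_critic = 0.3, 0.05, 0.1
theta, v = np.zeros(2), 0.5
probs = np.exp(theta) / np.sum(np.exp(theta))  # [0.5, 0.5]
rpe = 1.0 - v  # choose arm 0, receive reward 1 -> rpe = 0.5
theta[0] = (1 - alpha_forget) * theta[0] + alpha_learn * rpe * (1 - probs[0])  # +0.075
theta[1] = (1 - alpha_forget) * theta[1] - alpha_learn * rpe * probs[1]        # -0.075
v = (1 - alpha_critic) * v + alpha_critic * 1.0  # 0.55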
@@ -395,9 +402,7 @@ class AgentNetwork:
params: A set of Haiku parameters suitable for that architecture
"""

def __init__(self,
make_network: Callable[[], hk.RNNCore],
params: hk.Params):
def __init__(self, make_network: Callable[[], hk.RNNCore], params: hk.Params):

def step_network(
xs: np.ndarray, state: hk.State
@@ -449,9 +454,9 @@ class SessData(NamedTuple):
n_trials: int


def run_experiment(agent: Agent,
environment: EnvironmentBanditsDrift,
n_steps: int) -> SessData:
def run_experiment(
agent: Agent, environment: EnvironmentBanditsDrift, n_steps: int
) -> SessData:
"""Runs a behavioral session from a given agent and environment.

Args:
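Putting the pieces together, an end-to-end sketch (illustration only; the agent constructors are collapsed in this diff, so AgentQ's arguments here are hypothetical):

# End-to-end sketch. AgentQ(alpha=..., beta=...) is a guessed signature.
env = EnvironmentBanditsDrift(sigma=0.1, seed=0)
agent = AgentQ(alpha=0.3, beta=3.0)  # hypothetical constructor arguments
data = run_experiment(agent, env, n_steps=200)
print(data.n_trials)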