Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,30 @@ def step_env_and_evaluator(
output = output.replace(eval_state=eval_state)
return output, env_state, env_state_metadata, terminated, truncated, rewards

@dataclass(frozen=True)
class SinglePlayerGameState:
"""Stores the state of a single-player game for two evaluators playing independently.
- `key`: rng
- `env_state`: The initial environment state.
- `env_state_metadata`: Metadata associated with the initial environment state.
- `eval_state_1`: The internal state of the first evaluator.
- `eval_state_2`: The internal state of the second evaluator.
- `completed_1`: Whether the first evaluator's game is completed.
- `completed_2`: Whether the second evaluator's game is completed.
- `outcome_1`: The final reward of the first evaluator.
- `outcome_2`: The final reward of the second evaluator.
"""
key: jax.random.PRNGKey
env_state_1: chex.ArrayTree
env_state_2: chex.ArrayTree
env_state_metadata_1: StepMetadata
env_state_metadata_2: StepMetadata
eval_state_1: chex.ArrayTree
eval_state_2: chex.ArrayTree
completed_1: bool
completed_2: bool
outcome_1: float
outcome_2: float

@dataclass(frozen=True)
class TwoPlayerGameState:
Expand Down Expand Up @@ -142,6 +166,136 @@ class GameFrame:
completed: chex.Array
outcomes: chex.Array

def single_player_play(
key: jax.random.PRNGKey,
env_state: chex.ArrayTree,
env_state_metadata: StepMetadata,
eval_state: chex.ArrayTree,
evaluator: Evaluator,
params: chex.ArrayTree,
env_step_fn: EnvStepFn,
max_steps: int
) -> Tuple[chex.ArrayTree, StepMetadata, chex.ArrayTree, bool, float]:
"""
Executes a single player game until completion.

Args:
- `key`: RNG.
- `env_state`: Initial environment state.
- `env_state_metadata`: Metadata for the environment state.
- `eval_state`: Evaluator's internal state.
- `evaluator`: Evaluator to step through the environment.
- `params`: Evaluator's parameters.
- `env_step_fn`: Function to step the environment.
- `max_steps`: Maximum steps to take.

Returns:
- `env_state`: Final environment state.
- `env_state_metadata`: Final metadata.
- `eval_state`: Final evaluator state.
- `completed`: Whether the episode is completed.
- `outcome`: The total reward for the game.
"""
def step_fn(carry, _):
env_state, env_metadata, eval_state, key = carry
step_key, key = jax.random.split(key)

# Evaluate and take action
output = evaluator.evaluate(
key=step_key,
eval_state=eval_state,
env_state=env_state,
root_metadata=env_metadata,
params=params,
env_step_fn=env_step_fn
)
next_env_state, next_env_metadata = env_step_fn(env_state, output.action)
terminated = next_env_metadata.terminated
truncated = next_env_metadata.step > max_steps
completed = terminated | truncated
rewards = next_env_metadata.rewards
eval_state = jax.lax.cond(
completed,
lambda _: eval_state,
lambda _: evaluator.step(eval_state, output.action),
None
)
return (next_env_state, next_env_metadata, eval_state, step_key), (completed, rewards)

# Scan through the steps
(env_state, final_metadata, final_eval_state, _), (completed, rewards) = jax.lax.scan(
step_fn,
(env_state, env_state_metadata, eval_state, key),
xs=None,
length=max_steps
)

# Determine the outcome
final_outcome = jnp.where(completed, rewards.flatten(), 0.0)
final_completed = jnp.any(completed)
return env_state, final_metadata, final_eval_state, final_completed, final_outcome

def single_player_game(
key: jax.random.PRNGKey,
evaluator_1: Evaluator,
evaluator_2: Evaluator,
params_1: chex.ArrayTree,
params_2: chex.ArrayTree,
env_step_fn: EnvStepFn,
env_init_fn: EnvInitFn,
max_steps: int
) -> Tuple[chex.Array, SinglePlayerGameState, chex.Array]:
"""
Simulates a single-player game for two evaluators playing independently.

Args:
- `key`: RNG.
- `evaluator_1`: The first evaluator.
- `evaluator_2`: The second evaluator.
- `params_1`: Parameters for the first evaluator.
- `params_2`: Parameters for the second evaluator.
- `env_step_fn`: The environment step function.
- `env_init_fn`: The environment initialization function.
- `max_steps`: Maximum number of steps.

Returns:
- (chex.Array): A tuple of final outcomes for both evaluators.
- (SinglePlayerGameState): The final state of the single-player game.
"""
# Initialize RNG
init_key, key_1, key_2 = jax.random.split(key, 3)

# Initialize environment and evaluators
env_state, metadata = env_init_fn(init_key)
eval_state_1 = evaluator_1.init(template_embedding=env_state)
eval_state_2 = evaluator_2.init(template_embedding=env_state)

# No compilation as function is called only once
# Evaluate the game for both evaluators independently
env_state_1, metadata_1, eval_state_1, completed_1, outcome_1 = single_player_play(
key_1, env_state, metadata, eval_state_1, evaluator_1, params_1, env_step_fn, max_steps
)
env_state_2, metadata_2, eval_state_2, completed_2, outcome_2 = single_player_play(
key_2, env_state, metadata, eval_state_2, evaluator_2, params_2, env_step_fn, max_steps
)

# Create the final game state
game_state = SinglePlayerGameState(
key=key,
env_state_1=env_state_1,
env_state_2=env_state_2,
env_state_metadata_1=metadata_1,
env_state_metadata_2=metadata_2,
eval_state_1=eval_state_1,
eval_state_2=eval_state_2,
completed_1=completed_1,
completed_2=completed_2,
outcome_1=outcome_1,
outcome_2=outcome_2
)

return jnp.array([outcome_1, outcome_2]), game_state, jnp.array([0, 1])


def two_player_game_step(
state: TwoPlayerGameState,
Expand Down
108 changes: 108 additions & 0 deletions core/testing/single_player_tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from functools import partial
from typing import Dict, Tuple

import chex
from chex import dataclass
import jax
import jax.numpy as jnp

from core.common import single_player_game
from core.evaluators.evaluator import Evaluator
from core.testing.tester import BaseTester, TestState
from core.types import EnvInitFn, EnvStepFn


@dataclass(frozen=True)
class SinglePlayerTestState(TestState):
"""Internal state of a SinglePlayerTester. Stores the best parameters found so far.
- `best_params`: best performing parameters
"""
best_params: chex.ArrayTree


class SinglePlayerTester(BaseTester):
"""Implements a tester that evaluates an agent against the best performing parameters
found so far in a two-player game."""
def __init__(self, num_episodes: int, *args, **kwargs):
"""
Args:
- `num_episodes`: number of episodes to play in each test
"""
super().__init__(*args, num_keys=num_episodes, **kwargs)
self.num_episodes = num_episodes


def init(self, params: chex.ArrayTree, **kwargs) -> SinglePlayerTestState: #pylint: disable=unused-argument
"""Initializes the internal state of the TwoPlayerTester.
Args:
- `params`: initial parameters to store as the best performing
- can just be the initial parameters of the agent
"""
return SinglePlayerTestState(best_params=params)


def check_size_compatibilities(self, num_devices: int) -> None:
"""Checks if tester configuration is compatible with number of devices being utilized.

Args:
- `num_devices`: number of devices
"""
if self.num_episodes % num_devices != 0:
raise ValueError(f"{self.__class__.__name__}: number of episodes ({self.num_episodes}) must be divisible by number of devices ({num_devices})")


@partial(jax.pmap, axis_name='d', static_broadcasted_argnums=(0, 1, 2, 3, 4))
def test(self, max_steps: int, env_step_fn: EnvStepFn, env_init_fn: EnvInitFn, evaluator: Evaluator,
keys: chex.PRNGKey, state: SinglePlayerTestState, params: chex.ArrayTree) -> Tuple[SinglePlayerTestState, Dict, chex.ArrayTree, chex.Array]:
"""Test the agent against the best performing parameters found so far in a single-player game.

Args:
- `max_steps`: maximum number of steps per episode
- `env_step_fn`: environment step function
- `env_init_fn`: environment initialization function
- `evaluator`: the agent evaluator
- `keys`: rng
- `state`: internal state of the tester
- `params`: nn parameters used by agent

Returns:
- (SinglePlayerTestState, Dict, chex.ArrayTree, chex.Array)
- updated internal state of the tester
- metrics from the test
- frames from the test (used for rendering)
- player ids from the test (used for rendering)
"""

game_fn = partial(single_player_game,
evaluator_1 = evaluator,
evaluator_2 = evaluator,
params_1 = params,
params_2 = state.best_params,
env_step_fn = env_step_fn,
env_init_fn = env_init_fn,
max_steps = max_steps
)

results, frames, p_ids = jax.vmap(game_fn)(keys)
frames = jax.tree_map(lambda x: x[0], frames)
p_ids = p_ids[0]

num = keys.shape[-2]
new = jnp.sum(results[...,0,:]) / num
best = jnp.sum(results[...,1,:]) / num
avg = new/best

metrics = {
f"{self.name}_new_params": new,
f"{self.name}_best_params": best,
f"{self.name}_ratio": avg,
}

best_params = jax.lax.cond(
new > best,
lambda _: params,
lambda _: state.best_params,
None
)

return state.replace(best_params=best_params), metrics, frames, p_ids