Skip to content

Conversation

@alexunderch
Copy link

Usage example

import jax

from pogema import GridConfig, pogema_v0

grid_config = GridConfig(
    size=8,
    num_agents=5,
    obs_radius=2,
    seed=9,
    on_target="finish",
    max_episode_steps=128,
    integration="jax",
)

env = pogema_v0(grid_config=grid_config)


key = jax.random.key(0)

# resetting
state, env_state = env.reset(key)

policy = lambda rng: jax.random.randint(
    rng, (env.num_agents,), minval=0, maxval=env.action_space().n
)

# iteration


def step_fn(carry, _):
    state, env_state, step_key = carry
    act_key, key = jax.random.split(step_key)
    action = policy(act_key)  # random agent
    next_state, new_env_state, reward, terminated, truncated, info = env.step(
        action, env_state
    )
    return (
        (next_state, new_env_state, key),
        (state, next_state, action, reward, terminated, truncated, info),
    )


_, rollout = jax.lax.scan(step_fn, (state, env_state, key), None, length=70)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant