
[INVESTIGATION] Should VmapAutoResetWrapper replace AutoResetWrapper for vmapped environments? #1176

@jselvaraaj

Description


What do you want to investigate?

The mava/wrappers/auto_reset_wrapper.py documentation explicitly warns against using jax.vmap on an environment wrapped with AutoResetWrapper:

WARNING: do not jax.vmap the wrapped environment (e.g. do not use with the VmapWrapper), which would lead to inefficient computation due to both the step and reset functions being processed each time step is called. Please use the VmapAutoResetWrapper instead.

However, in various parts of the codebase, environments created by mava/utils/make_env.py are first wrapped with AutoResetWrapper (e.g., line 97), and jax.vmap is then applied to their step method inside the training loop. A specific example is in mava/systems/ppo/anakin/ff_mappo.py:

```python
# mava/systems/ppo/anakin/ff_mappo.py Lines 89-90
# Step environment
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
```

My hypothesis is that this pattern causes the inefficient computation described in the AutoResetWrapper warning. The wrapper's _auto_reset function (which calls _env.reset) is traced into the vmapped step function, and because the termination condition differs across the batch, the reset branch cannot be skipped per environment: it is evaluated for every environment on every step, even though it is only needed for the environments that terminate on that step. Jumanji provides VmapAutoResetWrapper specifically for efficient batched auto-resetting.
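The mechanism can be checked in isolation with a toy example. The sketch below does not use Mava or Jumanji: toy_reset and toy_step_with_auto_reset are hypothetical stand-ins for env.reset and for an AutoResetWrapper-style step that chooses between resetting and passing the state through with jax.lax.cond. Assuming the real wrapper branches in a similar way, jax.make_jaxpr shows that once the step is vmapped, the cond is replaced by a select_n over the outputs of both branches, i.e. the reset branch is computed for every environment.

```python
import jax
import jax.numpy as jnp


def toy_reset(state):
    # Hypothetical stand-in for env.reset: returns a fresh state.
    return jnp.zeros_like(state)


def toy_step_with_auto_reset(state):
    # Hypothetical single-environment step that terminates once the counter reaches 3.
    next_state = state + 1
    done = next_state >= 3
    # Per-environment choice between resetting and passing the state through.
    return jax.lax.cond(done, toy_reset, lambda s: s, next_state)


batched_step = jax.vmap(toy_step_with_auto_reset)
states = jnp.array([0, 1, 2])
print(batched_step(states))  # [1 2 0]: only the last environment terminated

# Unbatched, the predicate is a scalar and the jaxpr keeps a real `cond`.
unbatched_jaxpr = str(jax.make_jaxpr(toy_step_with_auto_reset)(jnp.int32(0)))
# Vmapped, the predicate is batched, so JAX rewrites the cond into `select_n`
# and both toy_reset and the identity branch run for all environments.
batched_jaxpr = str(jax.make_jaxpr(batched_step)(states))
print("cond" in unbatched_jaxpr, "select_n" in batched_jaxpr)  # True True
```

Whether this extra reset work matters in practice depends on how expensive the environment's reset is relative to its step, which is what the investigation needs to measure.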

We should investigate whether replacing the current AutoResetWrapper + jax.vmap(env.step) pattern with VmapAutoResetWrapper yields a measurable performance benefit, or whether the current implementation is sufficient or intentional for some reason.
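For comparison, the pattern VmapAutoResetWrapper is meant to enable can be sketched as: vmap the homogeneous step across all environments, then handle the heterogeneous reset per environment with jax.lax.map, so that each cond sees a scalar predicate and only the taken branch runs. This reflects my understanding of the approach, not Jumanji's actual source, and the toy functions are again hypothetical.

```python
import jax
import jax.numpy as jnp


def toy_step(state):
    # Hypothetical homogeneous step shared by every environment.
    return state + 1


def toy_reset(state):
    # Hypothetical stand-in for env.reset.
    return jnp.zeros_like(state)


def maybe_reset(state):
    # Heterogeneous part: reset this environment only if it terminated.
    return jax.lax.cond(state >= 3, toy_reset, lambda s: s, state)


def step_then_map_reset(states):
    # 1. Parallel step for the whole batch.
    states = jax.vmap(toy_step)(states)
    # 2. Sequential per-environment reset: lax.map is built on lax.scan, so the
    #    predicate inside maybe_reset stays scalar and cond is not turned into select.
    return jax.lax.map(maybe_reset, states)


states = jnp.array([0, 1, 2])
print(jax.jit(step_then_map_reset)(states))  # [1 2 0]
```

The trade-off is that lax.map processes environments one at a time, so this design should only pay off when resets are expensive relative to steps; with cheap resets, the fully parallel select may be just as fast or faster.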

Definition of done

The investigation will be considered complete when:

  1. We have confirmed whether the current usage of AutoResetWrapper with jax.vmap leads to the inefficiency described in the wrapper's documentation.
  2. We have determined if replacing AutoResetWrapper with jumanji.wrappers.VmapAutoResetWrapper (potentially removing the explicit jax.vmap from the training loop step) is feasible and beneficial in the context of Mava's system structure (e.g., in ff_mappo.py).
  3. A decision is made on whether to refactor the code to use VmapAutoResetWrapper or to keep the current implementation (with justification).
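To answer items 1 and 2 with numbers, a small timing helper along these lines could be used. It is a generic sketch, not part of Mava: benchmark is a hypothetical name; it excludes compilation with a warm-up call and blocks on the result so that JAX's asynchronous dispatch does not hide the compute time.

```python
import time
from typing import Any, Callable

import jax
import jax.numpy as jnp


def benchmark(fn: Callable[..., Any], *args: Any, n_iters: int = 100) -> float:
    """Return the mean wall-clock seconds per call of jax.jit(fn)(*args)."""
    if n_iters < 1:
        raise ValueError("n_iters must be at least 1")
    jitted = jax.jit(fn)
    # Warm-up call: triggers compilation so it is excluded from the timing.
    jax.block_until_ready(jitted(*args))
    start = time.perf_counter()
    for _ in range(n_iters):
        out = jitted(*args)
    # Dispatch is asynchronous; wait for the last result before stopping the clock.
    jax.block_until_ready(out)
    return (time.perf_counter() - start) / n_iters


if __name__ == "__main__":
    # Trivial placeholder workload; in the investigation, fn would be the
    # batched environment step under each wrapper configuration.
    xs = jnp.ones((1024,))
    print(f"{benchmark(lambda x: x * 2.0, xs) * 1e6:.1f} us per call")
```

For a fair comparison, fn would step the same batch of environments with identical initial states and actions under both configurations (jax.vmap(env.step) on an AutoResetWrapper environment versus the step of a VmapAutoResetWrapper environment), ideally at several episode lengths, since the reset frequency directly affects how much work the vmapped reset branch wastes.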

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
