Description
What do you want to investigate?
The `mava/wrappers/auto_reset_wrapper.py` documentation explicitly warns against using `jax.vmap` on an environment wrapped with `AutoResetWrapper`:
> WARNING: do not `jax.vmap` the wrapped environment (e.g. do not use with the `VmapWrapper`), which would lead to inefficient computation due to both the `step` and `reset` functions being processed each time `step` is called. Please use the `VmapAutoResetWrapper` instead.
However, in various parts of the codebase, environments created by `mava/utils/make_env.py` are first wrapped with `AutoResetWrapper` (e.g., line 97), and `jax.vmap` is then applied to their `step` method within the training loop. A specific example is in `mava/systems/ppo/anakin/ff_mappo.py`:
```python
# mava/systems/ppo/anakin/ff_mappo.py Lines 89-90
# Step environment
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
```
My hypothesis is that this pattern leads to the inefficient computation mentioned in the `AutoResetWrapper` warning because the `_auto_reset` function (which calls `_env.reset`) is JIT-compiled as part of the vmapped `step` function, even though it's only needed for environments that terminate on a given step. Jumanji provides `VmapAutoResetWrapper`, which is specifically designed for efficient batch auto-resetting.
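To make the hypothesis concrete, here is a minimal sketch that does not use Mava or Jumanji at all. `toy_step` and `toy_reset` are hypothetical stand-ins, and the sketch assumes the wrapper guards its reset with a `jax.lax.cond` on the done flag. It relies on a documented JAX behaviour: when `lax.cond` is vmapped over a batched predicate, it is lowered to a select, so both branches are computed for every batch element.

```python
import jax
import jax.numpy as jnp
from jax import lax


def toy_reset(state):
    # Hypothetical stand-in for a (potentially expensive) `env.reset`.
    return jnp.zeros_like(state)


def toy_step(state, action):
    # Hypothetical stand-in for `env.step`: the episode ends once state >= 3.
    next_state = state + action
    done = next_state >= 3
    # Auto-reset idea: only reset when the episode has finished.
    next_state = lax.cond(done, toy_reset, lambda s: s, next_state)
    return next_state, done


states = jnp.array([0, 2, 1])
actions = jnp.array([1, 1, 1])

# Same usage pattern as the training loop: vmap over the step function.
next_states, dones = jax.vmap(toy_step)(states, actions)

# Inspect the traced programs for the single and the batched version.
single_jaxpr = str(jax.make_jaxpr(toy_step)(states[0], actions[0]))
batched_jaxpr = str(jax.make_jaxpr(jax.vmap(toy_step))(states, actions))

has_cond_single = "cond" in single_jaxpr      # conditional kept as a real branch
has_select_batched = "select" in batched_jaxpr  # branch replaced by a select
print(has_cond_single, has_select_batched)
```

If the real `reset` is costly, the select form would pay that cost for every environment on every step, which is the inefficiency the warning appears to describe. Whether that cost is significant for Mava's environments is exactly what the investigation should measure.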
We should investigate if using `VmapAutoResetWrapper` instead of the current `AutoResetWrapper` + `jax.vmap(env.step)` pattern provides performance benefits or if the current implementation is sufficient/intended for some reason.
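For comparison, one possible alternative is to vmap only the plain step and move the reset decision to the batch level, so the conditional has a scalar predicate and survives as a real branch. The sketch below illustrates that general idea with hypothetical toy functions; it is not a description of how `VmapAutoResetWrapper` is actually implemented.

```python
import jax
import jax.numpy as jnp
from jax import lax


def toy_reset(state):
    # Hypothetical stand-in for a (potentially expensive) `env.reset`.
    return jnp.zeros_like(state)


def batched_step_with_outer_reset(states, actions):
    # Plain batched step, with no auto-reset inside it.
    next_states = states + actions
    done = next_states >= 3

    def reset_finished(s):
        # Only reached when at least one environment finished.
        return jnp.where(done, jax.vmap(toy_reset)(s), s)

    # Scalar predicate: the reset branch is skipped when nothing terminated.
    next_states = lax.cond(done.any(), reset_finished, lambda s: s, next_states)
    return next_states, done


states = jnp.array([0, 2, 1])
actions = jnp.array([1, 1, 1])

next_states, dones = batched_step_with_outer_reset(states, actions)
jaxpr_text = str(jax.make_jaxpr(batched_step_with_outer_reset)(states, actions))
keeps_cond = "cond" in jaxpr_text
print(next_states, dones, keeps_cond)
```

In this toy version the reset is still computed for the whole batch whenever any environment terminates, so the saving depends on how often episodes end relative to the batch size.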
Definition of done
The investigation will be considered complete when:
- We have confirmed whether the current usage of `AutoResetWrapper` with `jax.vmap` leads to the inefficiency described in the wrapper's documentation.
- We have determined if replacing `AutoResetWrapper` with `jumanji.wrappers.VmapAutoResetWrapper` (potentially removing the explicit `jax.vmap` from the training loop step) is feasible and beneficial in the context of Mava's system structure (e.g., in `ff_mappo.py`).
- A decision is made on whether to refactor the code to use `VmapAutoResetWrapper` or to keep the current implementation (with justification).
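A small timing helper along these lines could support the first two points. It is only a sketch: `time_jitted` and `dummy_step` are hypothetical names, and in a real comparison the function under test would be the rollout step built with each wrapper.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, repeats=10):
    """Return the mean wall-clock seconds per call of the jitted `fn`."""
    jitted = jax.jit(fn)
    # Warm-up call so compilation time is excluded from the measurement.
    jax.block_until_ready(jitted(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        out = jitted(*args)
    # JAX dispatches asynchronously, so wait for the last result.
    jax.block_until_ready(out)
    return (time.perf_counter() - start) / repeats


def dummy_step(x):
    # Hypothetical placeholder for a vmapped environment step.
    return jnp.sin(x) + 1.0


seconds_per_call = time_jitted(dummy_step, jnp.arange(1024.0))
print(f"{seconds_per_call:.6f} s per call")
```

Running the same helper on both variants, with identical batch sizes and hardware, would give a like-for-like number to justify the final decision.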