
[MAINTAIN] JAX 0.6.0 not supported #157

@emergenz

Description


I get the following error when using Stoix with JAX 0.6.0, because xla_client.Shape, which envpool's XLA bindings still use, has been removed:

(reward-redistribution) (reward-redistribution) [franz.srambical@gpusrv69 reward-redistribution]$ python stoix/systems/ppo/anakin/ff_ppo.py --logger.use_tb=True --system.epochs=1
Traceback (most recent call last):
  File "/ictstr01/home/aih/franz.srambical/reward-redistribution/stoix/systems/ppo/anakin/ff_ppo.py", line 26, in <module>
    from stoix.evaluator import evaluator_setup, get_distribution_act_fn
  File "/ictstr01/home/aih/franz.srambical/reward-redistribution/stoix/evaluator.py", line 27, in <module>
    from stoix.utils.env_factory import EnvFactory
  File "/ictstr01/home/aih/franz.srambical/reward-redistribution/stoix/utils/env_factory.py", line 9, in <module>
    import envpool
  File "/ictstr01/home/aih/franz.srambical/reward-redistribution/.venv/lib/python3.10/site-packages/envpool/__init__.py", line 16, in <module>
    import envpool.entry  # noqa: F401
  File "/ictstr01/home/aih/franz.srambical/reward-redistribution/.venv/lib/python3.10/site-packages/envpool/entry.py", line 17, in <module>
    import envpool.atari.registration  # noqa: F401
  File "/ictstr01/home/aih/franz.srambical/reward-redistribution/.venv/lib/python3.10/site-packages/envpool/atari/__init__.py", line 21, in <module>
    AtariEnvSpec, AtariDMEnvPool, AtariGymEnvPool, AtariGymnasiumEnvPool = py_env(
  File "/ictstr01/home/aih/franz.srambical/reward-redistribution/.venv/lib/python3.10/site-packages/envpool/python/api.py", line 34, in py_env
    DMEnvPoolMeta(pool_name.replace("EnvPool", "DMEnvPool"), (envpool,), {}),
  File "/ictstr01/home/aih/franz.srambical/reward-redistribution/.venv/lib/python3.10/site-packages/envpool/python/dm_envpool.py", line 52, in __new__
    from .lax import XlaMixin
  File "/ictstr01/home/aih/franz.srambical/reward-redistribution/.venv/lib/python3.10/site-packages/envpool/python/lax.py", line 22, in <module>
    from .xla_template import make_xla
  File "/ictstr01/home/aih/franz.srambical/reward-redistribution/.venv/lib/python3.10/site-packages/envpool/python/xla_template.py", line 30, in <module>
    ) -> Tuple[xla_client.Shape, ...]:
  File "/ictstr01/home/aih/franz.srambical/reward-redistribution/.venv/lib/python3.10/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(message)
AttributeError: Shape has been removed in JAX v0.6.0; use StableHLO instead.

We should update requirements.txt accordingly, for example by pinning JAX below 0.6.0 until envpool no longer depends on xla_client.Shape.
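Alongside a requirements pin (e.g. a `jax<0.6.0` line), a fail-fast version check could replace the deep AttributeError with an actionable message. The following is a minimal sketch, assuming JAX 0.6.0 is the first incompatible release, as the traceback indicates. The helper names `parse_release`, `jax_is_supported`, and `check_jax_version` are hypothetical and are not part of Stoix or envpool.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Assumption: envpool's XLA bindings (envpool/python/xla_template.py) break
# from JAX 0.6.0 onward, because xla_client.Shape was removed in that release.
FIRST_UNSUPPORTED_JAX = (0, 6, 0)


def parse_release(version: str) -> Tuple[int, int, int]:
    """Extract leading major.minor.patch numbers, e.g. '0.6.0rc1' -> (0, 6, 0)."""
    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    # Missing minor/patch components default to 0.
    return tuple(int(g) if g is not None else 0 for g in match.groups())


def jax_is_supported(version: str) -> bool:
    """True if this JAX version predates the removal of xla_client.Shape."""
    return parse_release(version) < FIRST_UNSUPPORTED_JAX


def check_jax_version(version: Optional[str] = None) -> None:
    """Raise early with a clear message instead of failing inside envpool."""
    if version is None:
        # Reads the installed distribution's version; raises
        # importlib.metadata.PackageNotFoundError if JAX is not installed.
        version = metadata.version("jax")
    if not jax_is_supported(version):
        raise RuntimeError(
            f"JAX {version} is not supported: envpool still uses xla_client.Shape, "
            "which was removed in JAX v0.6.0. Install jax<0.6.0 instead."
        )
```

Calling `check_jax_version()` before `import envpool` (for instance near the top of stoix/utils/env_factory.py, where the import happens in the traceback) would surface the incompatibility immediately. The pin in requirements.txt remains the primary fix; the check only improves the error message.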

Metadata


Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
