diff --git a/.github/unittest/linux_libs/scripts_brax/environment.yml b/.github/unittest/linux_libs/scripts_brax/environment.yml index 695278f3adb..63e26428d55 100644 --- a/.github/unittest/linux_libs/scripts_brax/environment.yml +++ b/.github/unittest/linux_libs/scripts_brax/environment.yml @@ -19,6 +19,6 @@ dependencies: - pyyaml - scipy - hydra-core - - jax[cuda12] + - jax[cuda12]>=0.7.0 - brax - psutil diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml index e1362721b12..9c310d805f5 100644 --- a/.github/unittest/linux_libs/scripts_minari/environment.yml +++ b/.github/unittest/linux_libs/scripts_minari/environment.yml @@ -24,7 +24,7 @@ dependencies: - ale-py - gymnasium-robotics - minari[create] - - jax + - jax>=0.7.0 - mujoco - mujoco-py<2.2,>=2.1 - minigrid diff --git a/pyproject.toml b/pyproject.toml index 9a6014d269d..ca7d962d7de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ offline-data = [ ] marl = ["vmas>=1.2.10", "pettingzoo>=1.24.1", "dm-meltingpot"] open_spiel = ["open_spiel>=1.5"] +brax = ["jax[cuda12]>=0.7.0", "brax"] llm = [ "transformers", "vllm", diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index 337e9e9a1a9..935b87460de 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -42,13 +42,13 @@ def _tree_flatten(x, batch_size: torch.Size): def _ndarray_to_tensor(value: jnp.ndarray | np.ndarray) -> torch.Tensor: # noqa: F821 - from jax import dlpack as jax_dlpack, numpy as jnp + from jax import numpy as jnp # JAX arrays generated by jax.vmap would have Numpy dtypes. if value.dtype in _dtype_conversion: value = value.view(_dtype_conversion[value.dtype]) if isinstance(value, jnp.ndarray): - dlpack_tensor = jax_dlpack.to_dlpack(value) + dlpack_tensor = value.__dlpack__() elif isinstance(value, np.ndarray): dlpack_tensor = value.__dlpack__() else: @@ -61,7 +61,9 @@ def _ndarray_to_tensor(value: jnp.ndarray | np.ndarray) -> torch.Tensor: # noqa def _tensor_to_ndarray(value: torch.Tensor) -> jnp.ndarray: # noqa: F821 from jax import dlpack as jax_dlpack - return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value.contiguous())) + # Detach the tensor to remove gradients before converting to DLPack + value = value.contiguous().detach() + return jax_dlpack.from_dlpack(value) def _get_object_fields(obj) -> dict: