Skip to content

Commit ad06e38

Browse files
authored
[BugFix] Fix dlpack deprecation in Brax tests (#3123)
1 parent d8dde2e commit ad06e38

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

.github/unittest/linux_libs/scripts_brax/environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@ dependencies:
1919
- pyyaml
2020
- scipy
2121
- hydra-core
22-
- jax[cuda12]
22+
- jax[cuda12]>=0.7.0
2323
- brax
2424
- psutil

.github/unittest/linux_libs/scripts_minari/environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ dependencies:
2424
- ale-py
2525
- gymnasium-robotics
2626
- minari[create]
27-
- jax
27+
- jax>=0.7.0
2828
- mujoco
2929
- mujoco-py<2.2,>=2.1
3030
- minigrid

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ offline-data = [
7575
]
7676
marl = ["vmas>=1.2.10", "pettingzoo>=1.24.1", "dm-meltingpot"]
7777
open_spiel = ["open_spiel>=1.5"]
78+
brax = ["jax[cuda12]>=0.7.0", "brax"]
7879
llm = [
7980
"transformers",
8081
"vllm",

torchrl/envs/libs/jax_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ def _tree_flatten(x, batch_size: torch.Size):
4242

4343

4444
def _ndarray_to_tensor(value: jnp.ndarray | np.ndarray) -> torch.Tensor: # noqa: F821
45-
from jax import dlpack as jax_dlpack, numpy as jnp
45+
from jax import numpy as jnp
4646

4747
# JAX arrays generated by jax.vmap would have Numpy dtypes.
4848
if value.dtype in _dtype_conversion:
4949
value = value.view(_dtype_conversion[value.dtype])
5050
if isinstance(value, jnp.ndarray):
51-
dlpack_tensor = jax_dlpack.to_dlpack(value)
51+
dlpack_tensor = value.__dlpack__()
5252
elif isinstance(value, np.ndarray):
5353
dlpack_tensor = value.__dlpack__()
5454
else:
@@ -61,7 +61,9 @@ def _ndarray_to_tensor(value: jnp.ndarray | np.ndarray) -> torch.Tensor: # noqa
6161
def _tensor_to_ndarray(value: torch.Tensor) -> jnp.ndarray: # noqa: F821
6262
from jax import dlpack as jax_dlpack
6363

64-
return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value.contiguous()))
64+
# Detach the tensor to remove gradients before converting to DLPack
65+
value = value.contiguous().detach()
66+
return jax_dlpack.from_dlpack(value)
6567

6668

6769
def _get_object_fields(obj) -> dict:

0 commit comments

Comments
 (0)