Skip to content

Commit bb3fb17

Browse files
committed
Fix JAX DLPack export to use __dlpack__ protocol directly
- Replaces deprecated jax.dlpack.to_dlpack() calls with the standard tensor.__dlpack__() method, which is the correct DLPack protocol interface for JAX 0.6+. Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
1 parent 86c416d commit bb3fb17

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def gpu_to_dlpack(tensor: jax.Array, stream):
2929
f"The function returned array residing on the device of "
3030
f"kind `{devices[0].platform}`, expected `gpu`."
3131
)
32-
return jax.dlpack.to_dlpack(tensor, stream=stream)
32+
return tensor.__dlpack__(stream=stream)
3333

3434

3535
def cpu_to_dlpack(tensor: jax.Array):
@@ -44,7 +44,7 @@ def cpu_to_dlpack(tensor: jax.Array):
4444
f"The function returned array residing on the device of "
4545
f"kind `{devices[0].platform}`, expected `cpu`."
4646
)
47-
return jax.dlpack.to_dlpack(tensor)
47+
return tensor.__dlpack__()
4848

4949

5050
def with_gpu_dl_tensors_as_arrays(callback):

0 commit comments

Comments
 (0)