Fix compatibility with flax-basic_example.ipynb after JAX update#6247
Fix compatibility with flax-basic_example.ipynb after JAX update#6247JanuszL merged 3 commits intoNVIDIA:mainfrom
Conversation
- Replaces deprecated jax.dlpack.to_dlpack() calls with the standard tensor.__dlpack__() method, which is the correct DLPack protocol interface for JAX 0.6+. Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
|
CI MESSAGE: [45688155]: BUILD STARTED |
Greptile SummaryThis PR fixes the deprecated Key changes:
Confidence Score: 3/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant DALI as DALI Pipeline
participant Transform as _function_transform.py
participant JAX as JAX Array
participant DLPack as DLPack Capsule
Note over DALI,DLPack: GPU path (JAX 0.6+)
DALI->>Transform: gpu_to_dlpack(tensor, stream)
Transform->>JAX: tensor.__dlpack__(stream=stream)
JAX-->>DLPack: capsule
DLPack-->>DALI: DLPack capsule
Note over DALI,DLPack: CPU path (JAX 0.6+)
DALI->>Transform: cpu_to_dlpack(tensor)
Transform->>JAX: tensor.__dlpack__()
JAX-->>DLPack: capsule
DLPack-->>DALI: DLPack capsule
Note over DALI,DLPack: Import path (unchanged)
DALI->>Transform: with_gpu/cpu_dl_tensors_as_arrays(callback)
Transform->>JAX: jax.dlpack.from_dlpack(t)
JAX-->>Transform: jax.Array
|
|
CI MESSAGE: [45688155]: BUILD PASSED |
|
CI MESSAGE: [45783663]: BUILD FAILED |
|
CI MESSAGE: [45784082]: BUILD STARTED |
|
CI MESSAGE: [45784725]: BUILD STARTED |
|
CI MESSAGE: [45784725]: BUILD PASSED |
tensor.dlpack() method, which is the correct DLPack protocol
interface for JAX 0.6+.
Category:
Bug fix (non-breaking change which fixes an issue)
Description:
tensor.dlpack() method, which is the correct DLPack protocol
interface for JAX 0.6+.
Additional information:
Affected modules and functionalities:
Key points relevant for the review:
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: N/A