Skip to content

Fix compatibility with flax-basic_example.ipynb after JAX update#6247

Merged
JanuszL merged 3 commits intoNVIDIA:mainfrom
JanuszL:fix_dlpack_jax
Mar 10, 2026
Merged

Fix compatibility with flax-basic_example.ipynb after JAX update#6247
JanuszL merged 3 commits intoNVIDIA:mainfrom
JanuszL:fix_dlpack_jax

Conversation

@JanuszL
Copy link
Contributor

@JanuszL JanuszL commented Mar 9, 2026

  • Replaces deprecated jax.dlpack.to_dlpack() calls with the standard
    tensor.dlpack() method, which is the correct DLPack protocol
    interface for JAX 0.6+.
  • Upgrades flax to 0.10.0
  • Adds ability to create JAX data iterator in in pmap-compatible mode

Category:

Bug fix (non-breaking change which fixes an issue)

Description:

  • Replaces deprecated jax.dlpack.to_dlpack() calls with the standard
    tensor.dlpack() method, which is the correct DLPack protocol
    interface for JAX 0.6+.
  • Upgrades flax to 0.10.0
  • Adds ability to create JAX data iterator in in pmap-compatible mode

Additional information:

Affected modules and functionalities:

  • _function_transform.py
  • setup_packages.py

Key points relevant for the review:

  • NA

Tests:

  • Existing tests apply
    • TL0_jupyter
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: N/A

- Replaces deprecated jax.dlpack.to_dlpack() calls with the standard
  tensor.__dlpack__() method, which is the correct DLPack protocol
  interface for JAX 0.6+.

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45688155]: BUILD STARTED

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 9, 2026

Greptile Summary

This PR fixes the deprecated jax.dlpack.to_dlpack() call by replacing it with the standard tensor.__dlpack__() protocol method, as required for JAX 0.6+, and bumps flax from 0.7.4 to 0.10.0. It also adds a new pmap_compatible parameter to DALIGenericIterator and the data_iterator decorator that controls whether the iterator output includes a leading device axis for use with jax.pmap.

Key changes:

  • _function_transform.py: gpu_to_dlpack and cpu_to_dlpack now call tensor.__dlpack__() directly instead of the deprecated jax.dlpack.to_dlpack().
  • iterator.py: New pmap_compatible parameter added to DALIGenericIterator.__init__, _data_iterator_impl, and data_iterator; auto-inferred as True when devices is provided, False otherwise.
  • qa/setup_packages.py: flax bumped from 0.7.4 to 0.10.0.
  • clu.py / peekable_data_iterator was not updated to expose pmap_compatible, which means users of peekable_data_iterator with a single devices entry will silently get output with an extra leading device axis ((1, batch, ...)) with no way to opt out.
  • pmap_compatible is not documented in the DALIGenericIterator class-level docstring, only in the data_iterator decorator's docstring.

Confidence Score: 3/5

  • The core DLPack fix is correct and clean, but the pmap_compatible addition has a gap in peekable_data_iterator that can silently break existing users.
  • The _function_transform.py change is a straightforward and correct API migration. The flax version bump is intentional. The pmap_compatible feature in iterator.py is well-designed but incomplete: peekable_data_iterator in clu.py was not updated to expose the parameter, meaning users who call peekable_data_iterator(devices=[single_device]) will see a silent output shape change from (batch, ...) to (1, batch, ...) with no escape hatch. This reduces confidence from an otherwise safe PR.
  • dali/python/nvidia/dali/plugin/jax/iterator.py and dali/python/nvidia/dali/plugin/jax/clu.py need attention: pmap_compatible should be added to peekable_data_iterator and the DALIGenericIterator class docstring should be updated.

Important Files Changed

Filename Overview
dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py Replaces deprecated jax.dlpack.to_dlpack() calls with tensor.__dlpack__() in gpu_to_dlpack and cpu_to_dlpack; the jax.dlpack import is correctly retained for from_dlpack usage elsewhere in the file.
dali/python/nvidia/dali/plugin/jax/iterator.py Adds a new pmap_compatible parameter to DALIGenericIterator, _data_iterator_impl, and data_iterator, with auto-inference logic; the parameter is undocumented in the DALIGenericIterator class docstring, and its counterpart in clu.py (peekable_data_iterator) was not updated, introducing a potential silent behavior change.
qa/setup_packages.py Bumps flax from 0.7.4 to 0.10.0 to align with JAX 0.6+ compatibility requirements; a large version jump but intentional and consistent with the DLPack API migration.

Sequence Diagram

sequenceDiagram
    participant DALI as DALI Pipeline
    participant Transform as _function_transform.py
    participant JAX as JAX Array
    participant DLPack as DLPack Capsule

    Note over DALI,DLPack: GPU path (JAX 0.6+)
    DALI->>Transform: gpu_to_dlpack(tensor, stream)
    Transform->>JAX: tensor.__dlpack__(stream=stream)
    JAX-->>DLPack: capsule
    DLPack-->>DALI: DLPack capsule

    Note over DALI,DLPack: CPU path (JAX 0.6+)
    DALI->>Transform: cpu_to_dlpack(tensor)
    Transform->>JAX: tensor.__dlpack__()
    JAX-->>DLPack: capsule
    DLPack-->>DALI: DLPack capsule

    Note over DALI,DLPack: Import path (unchanged)
    DALI->>Transform: with_gpu/cpu_dl_tensors_as_arrays(callback)
    Transform->>JAX: jax.dlpack.from_dlpack(t)
    JAX-->>Transform: jax.Array
Loading

Comments Outside Diff (2)

  1. dali/python/nvidia/dali/plugin/jax/clu.py, line 286-400 (link)

    pmap_compatible not exposed in peekable_data_iterator

    _data_iterator_impl now auto-enables pmap_compatible=True when devices is provided (line 379 of iterator.py). However, peekable_data_iterator in clu.py does not accept or forward a pmap_compatible argument, so this auto-enable cannot be overridden.

    This introduces a silent breaking change for existing callers of peekable_data_iterator(devices=[single_device]):

    • Before this PR: _num_gpus == 1 and _sharding is None → output shape is (batch_size, ...) (the simple path in _next_impl)
    • After this PR: effective_pmap_compatible = True is auto-set → output shape becomes (1, batch_size, ...) (the stacked path)

    There is no way for peekable_data_iterator users to pass pmap_compatible=False to suppress the new device axis and restore the pre-PR output shape. The parameter should be added to peekable_data_iterator's signature and forwarded to _data_iterator_impl for consistency with data_iterator.

  2. dali/python/nvidia/dali/plugin/jax/iterator.py, line 30-108 (link)

    Missing pmap_compatible in DALIGenericIterator class docstring

    The new pmap_compatible parameter is documented in the data_iterator decorator's docstring (line 489), but is absent from the DALIGenericIterator class-level docstring that describes __init__'s parameters (lines 35–108). Since DALIGenericIterator is a public class and the parameter is part of its constructor signature, the class docstring should document it alongside the other parameters such as sharding.

Last reviewed commit: 092567f

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45688155]: BUILD PASSED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45783663]: BUILD FAILED

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45784082]: BUILD STARTED

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45784725]: BUILD STARTED

@JanuszL JanuszL changed the title Fix JAX DLPack export to use __dlpack__ protocol directly Fix compatibility with flax-basic_example.ipynb after JAX update Mar 10, 2026
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [45784725]: BUILD PASSED

@JanuszL JanuszL merged commit 79f25c2 into NVIDIA:main Mar 10, 2026
6 checks passed
@JanuszL JanuszL deleted the fix_dlpack_jax branch March 10, 2026 20:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants