Skip to content

Commit 79f25c2

Browse files
authored
Fix compatibility with flax-basic_example.ipynb after JAX update (#6247)
- Replaces deprecated jax.dlpack.to_dlpack() calls with the standard tensor.__dlpack__() method, which is the correct DLPack protocol interface for JAX 0.6+. - Upgrades flax to 0.10.0 - Adds ability to create JAX data iterator in in pmap-compatible mode Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
1 parent 86847c2 commit 79f25c2

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def gpu_to_dlpack(tensor: jax.Array, stream):
2929
f"The function returned array residing on the device of "
3030
f"kind `{devices[0].platform}`, expected `gpu`."
3131
)
32-
return jax.dlpack.to_dlpack(tensor, stream=stream)
32+
return tensor.__dlpack__(stream=stream)
3333

3434

3535
def cpu_to_dlpack(tensor: jax.Array):
@@ -44,7 +44,7 @@ def cpu_to_dlpack(tensor: jax.Array):
4444
f"The function returned array residing on the device of "
4545
f"kind `{devices[0].platform}`, expected `cpu`."
4646
)
47-
return jax.dlpack.to_dlpack(tensor)
47+
return tensor.__dlpack__()
4848

4949

5050
def with_gpu_dl_tensors_as_arrays(callback):

dali/python/nvidia/dali/plugin/jax/iterator.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def __init__(
118118
last_batch_policy: LastBatchPolicy = LastBatchPolicy.FILL,
119119
prepare_first_batch: bool = True,
120120
sharding: Optional[Sharding] = None,
121+
pmap_compatible: Optional[bool] = None,
121122
):
122123
# check the assert first as _DaliBaseIterator would run the prefetch
123124
if len(set(output_map)) != len(output_map):
@@ -131,6 +132,10 @@ def __init__(
131132
), "`sharding` should be an instance of `jax.sharding.Sharding`"
132133
self._sharding = sharding
133134

135+
# When pmap_compatible is None (default), auto-infer: False for single-pipeline
136+
# iterators. _data_iterator_impl sets True automatically when devices are provided.
137+
self._pmap_compatible = pmap_compatible if pmap_compatible is not None else False
138+
134139
assert (
135140
last_batch_policy != LastBatchPolicy.PARTIAL
136141
), "JAX iterator does not support partial last batch policy."
@@ -170,7 +175,7 @@ def _next_impl(self):
170175
for category_id, category_name in enumerate(self.output_map):
171176
category_outputs = self._gather_outputs_for_category(pipelines_outputs, category_id)
172177

173-
if self._num_gpus == 1 and self._sharding is None:
178+
if self._num_gpus == 1 and self._sharding is None and not self._pmap_compatible:
174179
next_output[category_name] = category_outputs[0]
175180
else:
176181
self._assert_shards_shapes(category_outputs)
@@ -276,6 +281,7 @@ def _data_iterator_impl(
276281
prepare_first_batch: bool = True,
277282
sharding: Optional[Sharding] = None,
278283
devices: Optional[List[jax.Device]] = None,
284+
pmap_compatible: Optional[bool] = None,
279285
):
280286
"""Implementation of the data_iterator decorator. It is extracted to a separate function
281287
to be reused by the peekable iterator decorator.
@@ -309,6 +315,7 @@ def create_iterator(*args, checkpoints=None, **wrapper_kwargs):
309315
last_batch_padded=last_batch_padded,
310316
last_batch_policy=last_batch_policy,
311317
prepare_first_batch=prepare_first_batch,
318+
pmap_compatible=pmap_compatible,
312319
)
313320
else:
314321
pipelines = []
@@ -363,8 +370,14 @@ def create_iterator(*args, checkpoints=None, **wrapper_kwargs):
363370
last_batch_policy=last_batch_policy,
364371
prepare_first_batch=prepare_first_batch,
365372
sharding=sharding,
373+
pmap_compatible=pmap_compatible,
366374
)
367375
elif devices is not None:
376+
# Auto-enable pmap_compatible when devices are provided, unless the user
377+
# explicitly overrode it.
378+
effective_pmap_compatible = (
379+
pmap_compatible if pmap_compatible is not None else True
380+
)
368381
return iterator_type(
369382
pipelines=pipelines,
370383
output_map=output_map,
@@ -374,6 +387,7 @@ def create_iterator(*args, checkpoints=None, **wrapper_kwargs):
374387
last_batch_padded=last_batch_padded,
375388
last_batch_policy=last_batch_policy,
376389
prepare_first_batch=prepare_first_batch,
390+
pmap_compatible=effective_pmap_compatible,
377391
)
378392

379393
raise AssertionError(
@@ -396,6 +410,7 @@ def data_iterator(
396410
prepare_first_batch: bool = True,
397411
sharding: Optional[Sharding] = None,
398412
devices: Optional[List[jax.Device]] = None,
413+
pmap_compatible: Optional[bool] = None,
399414
):
400415
"""Decorator for DALI iterator for JAX. Decorated function when called returns DALI
401416
iterator for JAX.
@@ -471,6 +486,14 @@ def data_iterator(
471486
return outputs compatible with pmapped JAX functions.
472487
This argument is mutually exclusive with `sharding` argument. If `sharding`
473488
is provided, `devices` should be set to None.
489+
pmap_compatible : bool, optional, default = None
490+
Controls whether the iterator produces outputs with a leading device axis
491+
compatible with ``jax.pmap``. When ``None`` (default), it is inferred
492+
automatically: ``True`` when ``devices`` is provided, ``False`` otherwise.
493+
Set to ``True`` explicitly to force pmap-compatible output (shape
494+
``[num_devices, batch_per_device, ...]``) without using the ``devices``
495+
argument. Set to ``False`` to suppress the device axis even when ``devices``
496+
is provided.
474497
checkpoints : list of str, optional, default = None
475498
Checkpoints obtained with `.checkpoints()` method of the iterator.
476499
If provided, they will be used to restore the state of the pipelines.
@@ -506,4 +529,5 @@ def data_iterator(
506529
prepare_first_batch,
507530
sharding,
508531
devices,
532+
pmap_compatible,
509533
)

qa/setup_packages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def get_pyvers_name(self, url, cuda_version):
654654
"flax",
655655
[
656656
PckgVer(
657-
"0.7.4",
657+
"0.10.0",
658658
# Free-threaded Python build is incompatible with numpy<2.
659659
python_free_threaded=False,
660660
),

0 commit comments

Comments
 (0)