Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dali/python/nvidia/dali/plugin/jax/fn/_function_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def gpu_to_dlpack(tensor: jax.Array, stream):
f"The function returned array residing on the device of "
f"kind `{devices[0].platform}`, expected `gpu`."
)
return jax.dlpack.to_dlpack(tensor, stream=stream)
return tensor.__dlpack__(stream=stream)


def cpu_to_dlpack(tensor: jax.Array):
Expand All @@ -44,7 +44,7 @@ def cpu_to_dlpack(tensor: jax.Array):
f"The function returned array residing on the device of "
f"kind `{devices[0].platform}`, expected `cpu`."
)
return jax.dlpack.to_dlpack(tensor)
return tensor.__dlpack__()


def with_gpu_dl_tensors_as_arrays(callback):
Expand Down
26 changes: 25 additions & 1 deletion dali/python/nvidia/dali/plugin/jax/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
last_batch_policy: LastBatchPolicy = LastBatchPolicy.FILL,
prepare_first_batch: bool = True,
sharding: Optional[Sharding] = None,
pmap_compatible: Optional[bool] = None,
):
# check the assert first as _DaliBaseIterator would run the prefetch
if len(set(output_map)) != len(output_map):
Expand All @@ -131,6 +132,10 @@ def __init__(
), "`sharding` should be an instance of `jax.sharding.Sharding`"
self._sharding = sharding

# When pmap_compatible is None (default), auto-infer: False for single-pipeline
# iterators. _data_iterator_impl sets True automatically when devices are provided.
self._pmap_compatible = pmap_compatible if pmap_compatible is not None else False

assert (
last_batch_policy != LastBatchPolicy.PARTIAL
), "JAX iterator does not support partial last batch policy."
Expand Down Expand Up @@ -170,7 +175,7 @@ def _next_impl(self):
for category_id, category_name in enumerate(self.output_map):
category_outputs = self._gather_outputs_for_category(pipelines_outputs, category_id)

if self._num_gpus == 1 and self._sharding is None:
if self._num_gpus == 1 and self._sharding is None and not self._pmap_compatible:
next_output[category_name] = category_outputs[0]
else:
self._assert_shards_shapes(category_outputs)
Expand Down Expand Up @@ -276,6 +281,7 @@ def _data_iterator_impl(
prepare_first_batch: bool = True,
sharding: Optional[Sharding] = None,
devices: Optional[List[jax.Device]] = None,
pmap_compatible: Optional[bool] = None,
):
"""Implementation of the data_iterator decorator. It is extracted to a separate function
to be reused by the peekable iterator decorator.
Expand Down Expand Up @@ -309,6 +315,7 @@ def create_iterator(*args, checkpoints=None, **wrapper_kwargs):
last_batch_padded=last_batch_padded,
last_batch_policy=last_batch_policy,
prepare_first_batch=prepare_first_batch,
pmap_compatible=pmap_compatible,
)
else:
pipelines = []
Expand Down Expand Up @@ -363,8 +370,14 @@ def create_iterator(*args, checkpoints=None, **wrapper_kwargs):
last_batch_policy=last_batch_policy,
prepare_first_batch=prepare_first_batch,
sharding=sharding,
pmap_compatible=pmap_compatible,
)
elif devices is not None:
# Auto-enable pmap_compatible when devices are provided, unless the user
# explicitly overrode it.
effective_pmap_compatible = (
pmap_compatible if pmap_compatible is not None else True
)
return iterator_type(
pipelines=pipelines,
output_map=output_map,
Expand All @@ -374,6 +387,7 @@ def create_iterator(*args, checkpoints=None, **wrapper_kwargs):
last_batch_padded=last_batch_padded,
last_batch_policy=last_batch_policy,
prepare_first_batch=prepare_first_batch,
pmap_compatible=effective_pmap_compatible,
)

raise AssertionError(
Expand All @@ -396,6 +410,7 @@ def data_iterator(
prepare_first_batch: bool = True,
sharding: Optional[Sharding] = None,
devices: Optional[List[jax.Device]] = None,
pmap_compatible: Optional[bool] = None,
):
"""Decorator for DALI iterator for JAX. Decorated function when called returns DALI
iterator for JAX.
Expand Down Expand Up @@ -471,6 +486,14 @@ def data_iterator(
return outputs compatible with pmapped JAX functions.
This argument is mutually exclusive with the `sharding` argument. If `sharding`
is provided, `devices` should be set to None.
pmap_compatible : bool, optional, default = None
Controls whether the iterator produces outputs with a leading device axis
compatible with ``jax.pmap``. When ``None`` (default), it is inferred
automatically: ``True`` when ``devices`` is provided, ``False`` otherwise.
Set to ``True`` explicitly to force pmap-compatible output (shape
``[num_devices, batch_per_device, ...]``) without using the ``devices``
argument. Set to ``False`` to suppress the device axis even when ``devices``
is provided.
checkpoints : list of str, optional, default = None
Checkpoints obtained with `.checkpoints()` method of the iterator.
If provided, they will be used to restore the state of the pipelines.
Expand Down Expand Up @@ -506,4 +529,5 @@ def data_iterator(
prepare_first_batch,
sharding,
devices,
pmap_compatible,
)
2 changes: 1 addition & 1 deletion qa/setup_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def get_pyvers_name(self, url, cuda_version):
"flax",
[
PckgVer(
"0.7.4",
"0.10.0",
# Free-threaded Python build is incompatible with numpy<2.
python_free_threaded=False,
),
Expand Down