@@ -118,6 +118,7 @@ def __init__(
118118 last_batch_policy : LastBatchPolicy = LastBatchPolicy .FILL ,
119119 prepare_first_batch : bool = True ,
120120 sharding : Optional [Sharding ] = None ,
121+ pmap_compatible : Optional [bool ] = None ,
121122 ):
122123 # check the assert first as _DaliBaseIterator would run the prefetch
123124 if len (set (output_map )) != len (output_map ):
@@ -131,6 +132,10 @@ def __init__(
131132 ), "`sharding` should be an instance of `jax.sharding.Sharding`"
132133 self ._sharding = sharding
133134
135+ # When pmap_compatible is None (default), fall back to False here, regardless of the
136+ # number of pipelines. _data_iterator_impl passes True itself when devices are provided.
137+ self ._pmap_compatible = pmap_compatible if pmap_compatible is not None else False
138+
134139 assert (
135140 last_batch_policy != LastBatchPolicy .PARTIAL
136141 ), "JAX iterator does not support partial last batch policy."
@@ -170,7 +175,7 @@ def _next_impl(self):
170175 for category_id , category_name in enumerate (self .output_map ):
171176 category_outputs = self ._gather_outputs_for_category (pipelines_outputs , category_id )
172177
173- if self ._num_gpus == 1 and self ._sharding is None :
178+ if self ._num_gpus == 1 and self ._sharding is None and not self . _pmap_compatible :
174179 next_output [category_name ] = category_outputs [0 ]
175180 else :
176181 self ._assert_shards_shapes (category_outputs )
@@ -276,6 +281,7 @@ def _data_iterator_impl(
276281 prepare_first_batch : bool = True ,
277282 sharding : Optional [Sharding ] = None ,
278283 devices : Optional [List [jax .Device ]] = None ,
284+ pmap_compatible : Optional [bool ] = None ,
279285):
280286 """Implementation of the data_iterator decorator. It is extracted to a separate function
281287 to be reused by the peekable iterator decorator.
@@ -309,6 +315,7 @@ def create_iterator(*args, checkpoints=None, **wrapper_kwargs):
309315 last_batch_padded = last_batch_padded ,
310316 last_batch_policy = last_batch_policy ,
311317 prepare_first_batch = prepare_first_batch ,
318+ pmap_compatible = pmap_compatible ,
312319 )
313320 else :
314321 pipelines = []
@@ -363,8 +370,14 @@ def create_iterator(*args, checkpoints=None, **wrapper_kwargs):
363370 last_batch_policy = last_batch_policy ,
364371 prepare_first_batch = prepare_first_batch ,
365372 sharding = sharding ,
373+ pmap_compatible = pmap_compatible ,
366374 )
367375 elif devices is not None :
376+ # Auto-enable pmap_compatible when devices are provided, unless the user
377+ # explicitly overrode it.
378+ effective_pmap_compatible = (
379+ pmap_compatible if pmap_compatible is not None else True
380+ )
368381 return iterator_type (
369382 pipelines = pipelines ,
370383 output_map = output_map ,
@@ -374,6 +387,7 @@ def create_iterator(*args, checkpoints=None, **wrapper_kwargs):
374387 last_batch_padded = last_batch_padded ,
375388 last_batch_policy = last_batch_policy ,
376389 prepare_first_batch = prepare_first_batch ,
390+ pmap_compatible = effective_pmap_compatible ,
377391 )
378392
379393 raise AssertionError (
@@ -396,6 +410,7 @@ def data_iterator(
396410 prepare_first_batch : bool = True ,
397411 sharding : Optional [Sharding ] = None ,
398412 devices : Optional [List [jax .Device ]] = None ,
413+ pmap_compatible : Optional [bool ] = None ,
399414):
400415 """Decorator for DALI iterator for JAX. Decorated function when called returns DALI
401416 iterator for JAX.
@@ -471,6 +486,14 @@ def data_iterator(
471486 return outputs compatible with pmapped JAX functions.
472487 This argument is mutually exclusive with `sharding` argument. If `sharding`
473488 is provided, `devices` should be set to None.
489+ pmap_compatible : bool, optional, default = None
490+ Controls whether the iterator produces outputs with a leading device axis
491+ compatible with ``jax.pmap``. When ``None`` (default), it is inferred
492+ automatically: ``True`` when ``devices`` is provided, ``False`` otherwise.
493+ Set to ``True`` explicitly to force pmap-compatible output (shape
494+ ``[num_devices, batch_per_device, ...]``) without using the ``devices``
495+ argument. Set to ``False`` to suppress the device axis even when ``devices``
496+ is provided; this only has an effect when the iterator uses a single device.
474497 checkpoints : list of str, optional, default = None
475498 Checkpoints obtained with `.checkpoints()` method of the iterator.
476499 If provided, they will be used to restore the state of the pipelines.
@@ -506,4 +529,5 @@ def data_iterator(
506529 prepare_first_batch ,
507530 sharding ,
508531 devices ,
532+ pmap_compatible ,
509533 )
0 commit comments