diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 101d1c8e2..0c5732f80 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -624,6 +624,7 @@ def pmap( backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, @@ -644,6 +645,7 @@ def pmap( backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, @@ -663,6 +665,7 @@ def pmap( backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, @@ -746,6 +749,7 @@ def pmap( backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, transform_metadata=transform_metadata, graph=graph, graph_updates=graph_updates, @@ -779,6 +783,7 @@ def pmap( backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, ) @functools.wraps(f_unbound) @@ -822,6 +827,7 @@ def simple_pmap_wrapper(*args, **kwargs): backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, ) @functools.wraps(f)