From fde3b2e87db13676085e4b7d6400fb4d55289769 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 19 Mar 2026 23:07:26 -0700 Subject: [PATCH] Remove spmd_lowering from MeshComputation and do more cleanups after pmap, PmapSharding is gone PiperOrigin-RevId: 886581793 --- flax/nnx/transforms/iteration.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 101d1c8e2..0c5732f80 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -624,6 +624,7 @@ def pmap( backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, @@ -644,6 +645,7 @@ def pmap( backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, @@ -663,6 +665,7 @@ def pmap( backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), + global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, @@ -746,6 +749,7 @@ def pmap( backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, transform_metadata=transform_metadata, graph=graph, graph_updates=graph_updates, @@ -779,6 +783,7 @@ def pmap( backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, ) @functools.wraps(f_unbound) @@ -822,6 +827,7 @@ def simple_pmap_wrapper(*args, **kwargs): backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, ) @functools.wraps(f)