Skip to content

Commit ad3f54e

Browse files
hawkinspFlax Authors
authored andcommitted
Remove spmd_lowering from MeshComputation and do more cleanups after pmap, PmapSharding is gone
PiperOrigin-RevId: 886581793
1 parent ee6d3c5 commit ad3f54e

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

flax/nnx/transforms/iteration.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,7 @@ def pmap(
624624
backend: str | None = None,
625625
axis_size: int | None = None,
626626
donate_argnums: int | tp.Iterable[int] = (),
627+
global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
627628
# nnx specific
628629
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
629630
graph: bool | None = None,
@@ -644,6 +645,7 @@ def pmap(
644645
backend: str | None = None,
645646
axis_size: int | None = None,
646647
donate_argnums: int | tp.Iterable[int] = (),
648+
global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
647649
# nnx specific
648650
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
649651
graph: bool | None = None,
@@ -663,6 +665,7 @@ def pmap(
663665
backend: str | None = None,
664666
axis_size: int | None = None,
665667
donate_argnums: int | tp.Iterable[int] = (),
668+
global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
666669
# nnx specific
667670
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
668671
graph: bool | None = None,
@@ -746,6 +749,7 @@ def pmap(
746749
backend=backend,
747750
axis_size=axis_size,
748751
donate_argnums=donate_argnums,
752+
global_arg_shapes=global_arg_shapes,
749753
transform_metadata=transform_metadata,
750754
graph=graph,
751755
graph_updates=graph_updates,
@@ -779,6 +783,7 @@ def pmap(
779783
backend=backend,
780784
axis_size=axis_size,
781785
donate_argnums=donate_argnums,
786+
global_arg_shapes=global_arg_shapes,
782787
)
783788

784789
@functools.wraps(f_unbound)
@@ -822,6 +827,7 @@ def simple_pmap_wrapper(*args, **kwargs):
822827
backend=backend,
823828
axis_size=axis_size,
824829
donate_argnums=donate_argnums,
830+
global_arg_shapes=global_arg_shapes,
825831
)
826832

827833
@functools.wraps(f)

0 commit comments

Comments
 (0)