@@ -624,6 +624,7 @@ def pmap(
624624 backend : str | None = None ,
625625 axis_size : int | None = None ,
626626 donate_argnums : int | tp .Iterable [int ] = (),
627+ global_arg_shapes : tuple [tuple [int , ...], ...] | None = None ,
627628 # nnx specific
628629 transform_metadata : tp .Mapping [str , tp .Any ] = FrozenDict ({}),
629630 graph : bool | None = None ,
@@ -644,6 +645,7 @@ def pmap(
644645 backend : str | None = None ,
645646 axis_size : int | None = None ,
646647 donate_argnums : int | tp .Iterable [int ] = (),
648+ global_arg_shapes : tuple [tuple [int , ...], ...] | None = None ,
647649 # nnx specific
648650 transform_metadata : tp .Mapping [str , tp .Any ] = FrozenDict ({}),
649651 graph : bool | None = None ,
@@ -663,6 +665,7 @@ def pmap(
663665 backend : str | None = None ,
664666 axis_size : int | None = None ,
665667 donate_argnums : int | tp .Iterable [int ] = (),
668+ global_arg_shapes : tuple [tuple [int , ...], ...] | None = None ,
666669 # nnx specific
667670 transform_metadata : tp .Mapping [str , tp .Any ] = FrozenDict ({}),
668671 graph : bool | None = None ,
@@ -746,6 +749,7 @@ def pmap(
746749 backend = backend ,
747750 axis_size = axis_size ,
748751 donate_argnums = donate_argnums ,
752+ global_arg_shapes = global_arg_shapes ,
749753 transform_metadata = transform_metadata ,
750754 graph = graph ,
751755 graph_updates = graph_updates ,
@@ -779,6 +783,7 @@ def pmap(
779783 backend = backend ,
780784 axis_size = axis_size ,
781785 donate_argnums = donate_argnums ,
786+ global_arg_shapes = global_arg_shapes ,
782787 )
783788
784789 @functools .wraps (f_unbound )
@@ -822,6 +827,7 @@ def simple_pmap_wrapper(*args, **kwargs):
822827 backend = backend ,
823828 axis_size = axis_size ,
824829 donate_argnums = donate_argnums ,
830+ global_arg_shapes = global_arg_shapes ,
825831 )
826832
827833 @functools .wraps (f )
0 commit comments