 
 import flax.linen as nn
 import jax
-from jax.experimental import mesh_utils
 import numpy as np
 
 
@@ -68,7 +67,7 @@ class DataParallelPartitioner(Partitioner):
   """Data parallel partitioner."""
 
   def __init__(self, data_axis: str = "batch"):
-    self.mesh = jax.sharding.Mesh(jax.devices(), (data_axis,))
+    self.mesh = jax.make_mesh((jax.device_count(),), (data_axis,))
     self.data_sharding = jax.sharding.NamedSharding(
         self.mesh, jax.sharding.PartitionSpec(data_axis)
     )
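For reference, a minimal single-process sketch of what the `jax.make_mesh` construction above yields (the axis name follows the default in the diff; the batch shape is invented):

```python
import jax
import numpy as np

# 1D mesh over all available devices, named "batch" as in DataParallelPartitioner.
mesh = jax.make_mesh((jax.device_count(),), ("batch",))
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("batch")
)

# A batch whose leading dimension is divisible by the device count is split
# along "batch", one shard per device.
batch = np.zeros((8 * jax.device_count(), 128), dtype=np.float32)
sharded = jax.device_put(batch, data_sharding)
print(sharded.sharding)  # NamedSharding(..., spec=PartitionSpec('batch',))
```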
@@ -109,6 +108,12 @@ def partition_init(
       self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
   ) -> CreateStateFn:
     with jax.sharding.use_mesh(self.mesh):
+      if abstract_batch is not None:
+        abstract_state = jax.eval_shape(init_fn, abstract_batch)
+        specs = nn.get_partition_spec(abstract_state)
+        self.state_sharding = jax.tree.map(
+            lambda x: jax.sharding.NamedSharding(self.mesh, x), specs
+        )
       init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)
 
     def _wrapped_init(batch: PyTree) -> State:
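The new `abstract_batch` branch follows the standard Flax pattern of deriving parameter shardings from a shape-only evaluation. A toy, self-contained sketch of that pattern (the module, the `"model"` axis name, and all shapes are invented for illustration):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Tiny(nn.Module):
  """Single matmul whose kernel carries a partitioning annotation."""

  @nn.compact
  def __call__(self, x):
    kernel = self.param(
        "kernel",
        nn.with_partitioning(nn.initializers.lecun_normal(), (None, "model")),
        (x.shape[-1], 16),
    )
    return x @ kernel


def init_fn(batch):
  return Tiny().init(jax.random.PRNGKey(0), batch)


abstract_batch = jnp.zeros((4, 8))
# eval_shape allocates no real parameters, only shape/dtype structs that keep
# the nn.Partitioned metadata attached by nn.with_partitioning.
abstract_state = jax.eval_shape(init_fn, abstract_batch)
specs = nn.get_partition_spec(abstract_state)
# specs -> {'params': {'kernel': PartitionSpec(None, 'model')}}
# Mapping each spec to a NamedSharding over the mesh then gives the
# out_shardings passed to the jitted init, as in the diff above.
```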
@@ -145,12 +150,12 @@ class ModelParallelPartitioner(Partitioner):
   This only works with multi-controller Jax, i.e. communications along the ICI
   for TPUs. For scaling beyond a single TPU slice this needs to be extended to
   support Megascale XLA or single-controller Pathways. Consider using T5X, Pax,
-  or Gemax for these use cases.
+  MaxText externally or Gemax internally for these use cases.
 
-  Note: This assumes that all axes of the inputs except the final one are used
-  for data parallelism while the final one is used for model parallelism.
-  This tends to work well for 2D and 3D torus topologies since network latency
-  tends to be much higher for the leading axes.
+  By default, all axes of the input are used for data parallelism. This results
+  in fully-sharded data-parallelism for ND topologies or data-parallelism for 1D
+  topologies. The range of axes can be configured using the `dp_axes` argument,
+  i.e. axes[:dp_axes] will be used for data parallelism.
 
   IMPORTANT: `shard_inputs` operates on a per process batch. This means that the
   input batch size on CPU must already be the per process batch size,
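A small illustration of the `axes` / `dp_axes` convention described in the docstring above (axis names and sizes are hypothetical):

```python
# Three mesh axes; the first two are used for data parallelism, the last one
# for model parallelism. Only the leading axis may be sized -1.
axes = (("replica", -1), ("fsdp", 2), ("tensor", 4))
dp_axes = 2

dp_axis_names = tuple(name for name, _ in axes[:dp_axes])  # ("replica", "fsdp")
# The -1 leading size is resolved to len(devices) // (2 * 4) when the mesh is
# built; with the default dp_axes=None, all three axes would be data-parallel.
```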
@@ -160,45 +165,49 @@ class ModelParallelPartitioner(Partitioner):
 
   def __init__(
       self,
-      axes: Sequence[tuple[str, int]],
+      axes: Sequence[tuple[str, int]] = (("batch", -1),),
+      dp_axes: int | None = None,
       rules: Mapping[str, str] | None = None,
       aot_compile: bool = False,
       options: jax.stages.CompilerOptions | None = None,
+      devices: Sequence[jax.Device] | None = None,
   ):
-    if len(axes) < 2:
+    if not axes:
+      raise ValueError("At least one axis must be specified in `axes`.")
+    if dp_axes == 0:
+      raise ValueError(
+          "Data parallelism axes range must be positive or negative."
+      )
+
+    devices = devices if devices is not None else jax.devices()
+    axis_names = [axis for axis, _ in axes]
+    axis_sizes = [dim for _, dim in axes]
+    if any(dim <= 0 for dim in axis_sizes[1:]):
       raise ValueError(
-          "`axes` cannot less than 2D, use data-parallel "
-          f"partitioner instead. Got axes: {axes}."
+          "All dimensions except the first in the axes must be positive "
+          f"integers. Got axes: {axes}."
       )
+    if axis_sizes[0] == -1:
+      axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:])
 
-    mesh_devices = mesh_utils.create_device_mesh([dim for _, dim in axes])
-    self.mesh = jax.sharding.Mesh(mesh_devices, [axis for axis, _ in axes])
+    self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
     self.rules = rules
     self.aot_compile = aot_compile
     self.options = options
 
-    dp_axes, dp_dims = zip(*axes[:-1])
-    _, mp_dim = axes[-1]
-
-    if math.prod(dp_dims) % jax.process_count() != 0:
+    dp_axis_names, dp_axis_sizes = zip(*axes[:dp_axes])
+    num_processes = jax.process_count()
+    if math.prod(dp_axis_sizes) % num_processes != 0:
       raise ValueError(
           "The data parallel dimensions in the mesh must be divisible by the"
           " number of processes as we assume data parallelism across"
-          f" processes. Got process count: {jax.process_count()} and data"
-          f" parallelism dimensions: {dp_dims} for axes: {axes} and mesh"
-          f" devices: {self.mesh.devices}."
-      )
-    if jax.local_device_count() % mp_dim != 0:
-      raise ValueError(
-          "The number of local devices on each host must be divisible by the"
-          " model dimension as we assume model parallelism across local"
-          f" devices. Got local device count: {jax.local_device_count()} and"
-          f" model parallelism dimension: {mp_dim} for axes: {axes} and mesh"
+          f" processes. Got process count: {num_processes} and data"
+          f" parallelism dimensions: {dp_axis_sizes} for axes: {axes} and mesh"
           f" devices: {self.mesh.devices}."
       )
 
     self.data_sharding = jax.sharding.NamedSharding(
-        self.mesh, jax.sharding.PartitionSpec(dp_axes)
+        self.mesh, jax.sharding.PartitionSpec(dp_axis_names)
     )
     self.state_sharding = None
     self.abstract_batch = None
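As a worked example of the `-1` resolution and divisibility checks above, consider a hypothetical setup with 16 devices spread over 2 processes and `axes=(("batch", -1), ("model", 4))`, `dp_axes=1`:

```python
import math

num_devices = 16   # hypothetical
num_processes = 2  # hypothetical
axes = [("batch", -1), ("model", 4)]

axis_sizes = [dim for _, dim in axes]
axis_sizes[0] = num_devices // math.prod(axis_sizes[1:])  # -1 -> 16 // 4 = 4

dp_axis_sizes = axis_sizes[:1]  # data-parallel sizes with dp_axes=1 -> [4]
assert math.prod(dp_axis_sizes) % num_processes == 0  # 4 % 2 == 0, so valid
```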