 
 import jax
 import jax.numpy as jnp
+from jax.sharding import NamedSharding
+
+from MaxText.common_types import ShardMode
+from MaxText.maxtext_utils import maybe_shard_with_name
 
 
 def gradient_accumulation_loss_and_grad(
@@ -58,6 +62,17 @@ def gradient_accumulation_loss_and_grad(
     - final_aux (PyTree): Auxiliary outputs, summed across microbatches.
     - raw_grads (PyTree): The accumulated and averaged gradients.
   """
+
+  def _maybe_shard_with_name(inputs, sharding_names):
+    """Wrapper around maybe_shard_with_name with shard_mode fixed to config.shard_mode."""
+    return maybe_shard_with_name(inputs, sharding_names, config.shard_mode)
+
+  # For more efficient DP/ZeRO-1 + GA: accumulate gradients unreduced on the data axis and reduce once after the loop
+  if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1:
+    ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings)
+    grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings)
+  else:
+    ga_params_shardings = grad_shardings = params_shardings
   # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints
   # so that all-gather is done once in the lower precision before the gradient accumulation loop
   if config.shard_optimizer_over_data:
@@ -68,15 +83,14 @@ def convert_to_bf16(param):
       return param
 
     ga_params = jax.tree_util.tree_map(convert_to_bf16, params)
-    ga_params = jax.tree.map(jax.lax.with_sharding_constraint, ga_params, params_shardings)
   else:
     ga_params = params
 
+  ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings)
   grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True)
 
   def accumulate_gradient(acc_grad_and_loss, data):
     ga_params = acc_grad_and_loss["ga_params"]
-
     (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True)
     acc_grad_and_loss["loss"] += aux["total_loss"]
     acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"]
@@ -94,7 +108,7 @@ def reshape_to_microbatch_accumulations(batch_arr):
 
   data = jax.tree_util.tree_map(reshape_to_microbatch_accumulations, data)
   init_grad = jax.tree_util.tree_map(jnp.zeros_like, ga_params)
-  init_grad = jax.tree.map(jax.lax.with_sharding_constraint, init_grad, params_shardings)
+  init_grad = jax.tree.map(_maybe_shard_with_name, init_grad, grad_shardings)
   init_grad_and_loss = {
       "loss": 0.0,
       "grad": init_grad,
@@ -113,9 +127,23 @@ def reshape_to_microbatch_accumulations(batch_arr):
       + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps
   )
   raw_grads = grad_and_loss["grad"]
-  if config.shard_optimizer_over_data:
-    raw_grads = jax.tree.map(jax.lax.with_sharding_constraint, raw_grads, params_shardings)
+  raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, params_shardings)
   raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads)
   aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux)  # pytype: disable=module-attr
 
   return loss, aux, raw_grads
+
+
+# GA helper functions
+def update_sharding_for_reduced(sharding: NamedSharding) -> NamedSharding:
+  """
+  Mark the "data" axis of the given NamedSharding as reduced.
+  """
+  return sharding.update(spec=sharding.spec.update(reduced={"data"}))
+
+
+def update_sharding_for_unreduced(sharding: NamedSharding) -> NamedSharding:
+  """
+  Mark the "data" axis of the given NamedSharding as unreduced.
+  """
+  return sharding.update(spec=sharding.spec.update(unreduced={"data"}))
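
A minimal sketch (not taken from the MaxText sources) of how the two helpers compose with the explicit-sharding branch above. It assumes a single "data" mesh axis with explicit axis types, a replicated parameter sharding, and a JAX version that supports the reduced/unreduced PartitionSpec fields used in the diff; the mesh construction and variable names are illustrative only.

import jax
import numpy as np
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec

# 1-D mesh over all local devices with an explicit "data" axis, mirroring the
# ShardMode.EXPLICIT + ici_data_parallelism > 1 branch above.
mesh = Mesh(np.array(jax.devices()), ("data",), axis_types=(AxisType.Explicit,))

# A parameter replicated over "data" (the DP/ZeRO-1 case).
param_sharding = NamedSharding(mesh, PartitionSpec())

# update_sharding_for_unreduced / update_sharding_for_reduced are the helpers
# defined at the end of the diff above.
# Accumulated gradients hold per-data-shard partial sums, so they are tagged
# unreduced on "data"; the cross-data reduction happens once, when raw_grads is
# constrained back to params_shardings after the accumulation loop.
grad_sharding = update_sharding_for_unreduced(param_sharding)
# Parameters consumed inside the loop get the counterpart reduced tag.
ga_param_sharding = update_sharding_for_reduced(param_sharding)

print(grad_sharding.spec, ga_param_sharding.spec)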