@@ -1261,11 +1261,11 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
 
     if self.config.moe_fsdp_use_two_stage_all_gather:
       # Unshard on fsdp axis
-      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", "embed_tensor_transpose", "mlp"))
-      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", "embed_tensor_transpose", "mlp"))
+      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
+      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
 
       # Unshard on fsdp_transpose axis
-      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", "mlp", "embed_tensor_transpose"))
+      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp", "embed_tensor_transpose"))
 
       # Make sure XLA does not optimize by combining the above all-gather that unshards
       # on the fsdp axis with the subsequent unshard on the fsdp_transpose axis
@@ -1274,9 +1274,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       wo_kernel = jax.lax.optimization_barrier(wo_kernel)
 
       # Unshard on both the fsdp and fsdp_transpose axes
-      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
+      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose"))
 
     if self.get_tensor_transpose_parallelism_size() > 1:
       input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed")