@@ -1256,11 +1256,11 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
 
     if self.config.moe_fsdp_use_two_stage_all_gather:
       # Unshard on fsdp axis
-      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", "embed_tensor_transpose", "mlp"))
-      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", "embed_tensor_transpose", "mlp"))
+      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
+      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
 
       # Unshard on fsdp_transpose axis
-      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", "mlp", "embed_tensor_transpose"))
+      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp", "embed_tensor_transpose"))
 
       # Make sure XLA does not optimize by combining above All-Gather to unshard
       # on FSDP axis and the subsequent unshard on fsdp_transpose axis
@@ -1269,9 +1269,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       wo_kernel = jax.lax.optimization_barrier(wo_kernel)
 
       # Unshard on both fsdp and fsdp_transpose axes
-      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
+      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose"))
 
     if self.get_tensor_transpose_parallelism_size() > 1:
       input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed")
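For context on the pattern this diff touches, below is a minimal sketch of a two-stage unshard guarded by `jax.lax.optimization_barrier`. It is not the MaxText implementation: the 2x2 mesh (4 devices), the axis names `fsdp`/`fsdp_transpose`, and the direct `with_sharding_constraint` calls are assumptions standing in for MaxText's logical-axis rules and its `_maybe_shard_with_logical` helper.

```python
# Hypothetical sketch of the two-stage all-gather; assumes 4 devices and made-up
# mesh axis names. Not the MaxText code path.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), axis_names=("fsdp", "fsdp_transpose"))

@jax.jit
def two_stage_unshard(w):
  # Stage 1: all-gather over "fsdp" only; the weight stays sharded on "fsdp_transpose".
  w = jax.lax.with_sharding_constraint(w, NamedSharding(mesh, P(None, "fsdp_transpose")))
  # The barrier keeps XLA from folding the two constraints into one combined all-gather.
  w = jax.lax.optimization_barrier(w)
  # Stage 2: all-gather over "fsdp_transpose" so the weight ends up fully replicated.
  return jax.lax.with_sharding_constraint(w, NamedSharding(mesh, P(None, None)))

w = jax.device_put(jnp.ones((1024, 1024)), NamedSharding(mesh, P("fsdp", "fsdp_transpose")))
w_full = two_stage_unshard(w)  # gathered in two separate collectives
```

Without the barrier, XLA may canonicalize the intermediate sharding away and emit a single all-gather over both axes, which is exactly what the comment in the diff says the two-stage path is meant to prevent.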