Commit 6f35ece

Merge pull request #2797 from AI-Hypercomputer:qinwen/check-in
PiperOrigin-RevId: 842443607
2 parents e02aa18 + eaffeef commit 6f35ece

File tree

3 files changed: +8 −6 lines

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -449,6 +449,7 @@ logical_axis_rules: [
   ['cache_kv', []],
   ['cache_sequence', []],
   ['exp', 'expert'],
+  ['exp_with_fsdp', 'fsdp'],
   ['paged_kv_heads', ['tensor']],
   ['num_pages', []],
   ['tokens_per_page', []],
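For reference, a minimal sketch (not MaxText code) of how a logical-axis rule such as ['exp_with_fsdp', 'fsdp'] resolves a logical axis name into a mesh axis through Flax's logical sharding rules. Only ('exp', 'expert') and the new ('exp_with_fsdp', 'fsdp') entries come from this diff; the other rule entries below are illustrative assumptions, not the full base.yml rule set.

import flax.linen as nn

# Illustrative subset of logical_axis_rules; only the 'exp' and 'exp_with_fsdp'
# entries are taken from this commit, the rest are assumed for the example.
rules = (
    ("exp", "expert"),
    ("exp_with_fsdp", "fsdp"),
    ("embed_tensor_transpose", "tensor_transpose"),
    ("mlp", "tensor"),
)

# Resolve the logical axis names used in moe.py into concrete mesh axes.
pspec = nn.logical_to_mesh_axes(("exp_with_fsdp", "embed_tensor_transpose", "mlp"), rules)
print(pspec)  # PartitionSpec('fsdp', 'tensor_transpose', 'tensor')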

src/MaxText/layers/deepseek.py

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ def self_attention_with_norm(
       mscale=cfg.mscale,
       rope_factor=cfg.rope_factor,
       model_mode=model_mode,
+      attn_logits_soft_cap=cfg.attn_logits_soft_cap,
   )

   attention_lnx, _ = attention_layer(
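This change only threads the config value attn_logits_soft_cap through to the attention layer. A minimal sketch of the common tanh soft-capping convention (as popularized by Gemma-style models) is shown below; it assumes MaxText's attention layer follows the same convention, which this diff does not show.

import jax.numpy as jnp

def soft_cap(logits, attn_logits_soft_cap):
  # Squash attention logits smoothly into (-cap, cap).
  if attn_logits_soft_cap:  # treat None / 0.0 as "no capping"
    logits = attn_logits_soft_cap * jnp.tanh(logits / attn_logits_soft_cap)
  return logits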

src/MaxText/layers/moe.py

Lines changed: 6 additions & 6 deletions
@@ -1261,11 +1261,11 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):

     if self.config.moe_fsdp_use_two_stage_all_gather:
       # Unshard on fsdp axis
-      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", "embed_tensor_transpose", "mlp"))
-      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", "embed_tensor_transpose", "mlp"))
+      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
+      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))

       # Unshard on fsdp_transpose axis
-      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", "mlp", "embed_tensor_transpose"))
+      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp", "embed_tensor_transpose"))

       # Make sure XLA does not optimize by combining above All-Gather to unshard
       # on FSDP axis and the subsequent unshard on fsdp_transpose axis
@@ -1274,9 +1274,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       wo_kernel = jax.lax.optimization_barrier(wo_kernel)

       # Unshard on both fsdp and fsdp_transpose transpose
-      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
+      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose"))

     if self.get_tensor_transpose_parallelism_size() > 1:
       input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed")
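The renamed logical axis changes which mesh axes the kernels are gathered over; the surrounding two-stage unshard pattern guarded by jax.lax.optimization_barrier stays the same. Below is a self-contained sketch of that pattern using plain JAX sharding constraints instead of MaxText's _maybe_shard_with_logical helper; the mesh layout and array shape are illustrative assumptions.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 2D device mesh named after MaxText's fsdp / fsdp_transpose axes.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("fsdp", "fsdp_transpose"))

@jax.jit
def two_stage_unshard(w):
  # Stage 1: all-gather over fsdp only; keep fsdp_transpose sharding on the last dim.
  w = jax.lax.with_sharding_constraint(w, NamedSharding(mesh, P(None, "fsdp_transpose")))
  # Barrier so XLA cannot merge the two constraints into a single all-gather.
  w = jax.lax.optimization_barrier(w)
  # Stage 2: also all-gather over fsdp_transpose, leaving w fully replicated.
  w = jax.lax.with_sharding_constraint(w, NamedSharding(mesh, P(None, None)))
  return w

# Array dims should be divisible by the corresponding mesh axis sizes.
print(two_stage_unshard(jnp.ones((8, 128))).shape)  # (8, 128)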
