Commit eaffeef

committed
support FSDP sharding in the 2D all-gather in MoE and support capped attention as an option in DeepSeek
1 parent 5489efd commit eaffeef

File tree

3 files changed: +8 −6 lines changed


src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -449,6 +449,7 @@ logical_axis_rules: [
   ['cache_kv', []],
   ['cache_sequence', []],
   ['exp', 'expert'],
+  ['exp_with_fsdp', 'fsdp'],
   ['paged_kv_heads', ['tensor']],
   ['num_pages', []],
   ['tokens_per_page', []],
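
The new rule maps the logical axis name exp_with_fsdp onto the physical fsdp mesh axis, so a tensor dimension annotated with that name is sharded across FSDP devices. Below is a minimal sketch of how such logical-to-mesh resolution works, using plain jax.sharding and a hand-written rules table that mirrors a few base.yml entries; the logical_to_pspec helper and the toy mesh are illustrative, not MaxText's actual implementation.

# Minimal sketch: resolve logical axis names to mesh axes via a rules table
# that mirrors a few base.yml entries. Illustrative only; MaxText resolves
# these rules through Flax's logical partitioning machinery.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

RULES = {
    "exp": "expert",
    "exp_with_fsdp": "fsdp",          # the rule added in this commit
    "embed_tensor_transpose": None,   # replicated in this toy setup
    "mlp": "tensor",
}

def logical_to_pspec(logical_axes):
  """Look up each logical name and build a PartitionSpec (None = replicated)."""
  return PartitionSpec(*(RULES.get(name) for name in logical_axes))

devices = np.array(jax.devices()).reshape(-1, 1, 1)  # (fsdp, expert, tensor)
mesh = Mesh(devices, axis_names=("fsdp", "expert", "tensor"))

# An expert kernel annotated with the new name is split along the fsdp axis.
pspec = logical_to_pspec(("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
w0 = jax.device_put(
    np.zeros((len(jax.devices()) * 2, 16, 32), np.float32),
    NamedSharding(mesh, pspec),
)
print(pspec)  # PartitionSpec('fsdp', None, 'tensor')

In MaxText the same resolution is driven by the logical_axis_rules list in base.yml, which is why adding the ['exp_with_fsdp', 'fsdp'] entry is enough to make the new logical name usable from moe.py.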

src/MaxText/layers/deepseek.py

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ def self_attention_with_norm(
       mscale=cfg.mscale,
       rope_factor=cfg.rope_factor,
       model_mode=model_mode,
+      attn_logits_soft_cap=cfg.attn_logits_soft_cap,
   )

   attention_lnx, _ = attention_layer(
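
The attn_logits_soft_cap forwarded here is the "capped attention" option from the commit message: attention logits are squashed so they stay within a bounded range before the softmax. A small sketch of that capping follows, assuming the common cap * tanh(logits / cap) form; the shapes and the soft_cap_logits helper are illustrative, not the MaxText attention kernel itself.

# Sketch of attention-logit soft capping, assuming the common
# cap * tanh(logits / cap) form; outputs are smoothly bounded to (-cap, cap).
import jax
import jax.numpy as jnp

def soft_cap_logits(logits, cap):
  """Return logits unchanged when cap is unset, else squash them with tanh."""
  if cap is None or cap == 0.0:
    return logits
  return jnp.tanh(logits / cap) * cap

# Toy attention scores: (batch, heads, q_len, kv_len).
q = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 8, 16))
logits = jnp.einsum("bhqd,bhkd->bhqk", q, k)

capped = soft_cap_logits(logits, cap=30.0)
weights = jax.nn.softmax(capped, axis=-1)  # softmax now sees bounded scores

Because the cap is applied before the softmax, extreme dot products cannot dominate the attention distribution, which is the usual motivation for enabling an option like this.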

src/MaxText/layers/moe.py

Lines changed: 6 additions & 6 deletions
@@ -1256,11 +1256,11 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
 
     if self.config.moe_fsdp_use_two_stage_all_gather:
       # Unshard on fsdp axis
-      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", "embed_tensor_transpose", "mlp"))
-      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", "embed_tensor_transpose", "mlp"))
+      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
+      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
 
       # Unshard on fsdp_transpose axis
-      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", "mlp", "embed_tensor_transpose"))
+      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp", "embed_tensor_transpose"))
 
       # Make sure XLA does not optimize by combining above All-Gather to unshard
       # on FSDP axis and the subsequent unshard on fsdp_transpose axis
@@ -1269,9 +1269,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       wo_kernel = jax.lax.optimization_barrier(wo_kernel)
 
       # Unshard on both fsdp and fsdp_transpose transpose
-      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
+      w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose"))
 
       if self.get_tensor_transpose_parallelism_size() > 1:
         input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed")
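
On the moe.py side, the expert dimension of the kernels is now annotated with exp_with_fsdp, which resolves to the fsdp mesh axis, so the kernels can stay FSDP-sharded along the expert dimension while the two-stage all-gather progressively unshards the other dimensions; the jax.lax.optimization_barrier calls keep XLA from fusing the two resharding steps into a single all-gather. Below is a stripped-down sketch of that staged-resharding-with-barrier pattern using plain sharding constraints; the mesh axes, shapes, and staging order are illustrative, not MaxText's actual layout.

# Sketch of a two-stage unshard: reshard one mesh axis at a time and place an
# optimization_barrier between the stages so XLA cannot merge them into a
# single all-gather. Axes and shapes are illustrative only.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(-1, 1, 1)
mesh = Mesh(devices, axis_names=("fsdp", "fsdp_transpose", "tensor"))

def staged_unshard(w):
  # Stage 1: keep dim 0 sharded on fsdp; gather dim 1 on fsdp_transpose,
  # leaving it sharded on tensor only.
  w = jax.lax.with_sharding_constraint(w, NamedSharding(mesh, P("fsdp", "tensor")))
  # Barrier so the two reshardings stay as separate collectives.
  w = jax.lax.optimization_barrier(w)
  # Stage 2: gather dim 1 completely; dim 0 stays sharded on fsdp throughout.
  w = jax.lax.with_sharding_constraint(w, NamedSharding(mesh, P("fsdp", None)))
  return w

w = jnp.zeros((8 * jax.device_count(), 64))
w = jax.device_put(w, NamedSharding(mesh, P("fsdp", ("fsdp_transpose", "tensor"))))
out = jax.jit(staged_unshard)(w)

In the real code the constraints go through self._maybe_shard_with_logical with logical names ("exp_with_fsdp", ..., "mlp") and ("exp_with_fsdp", ..., "mlp_no_fsdp"), and the logical_axis_rules entry added in base.yml is what ties exp_with_fsdp to the fsdp mesh axis.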
