Commit ee4e8cc

add weight_sum_fp32 config

Update base.yml and moe.py.

1 parent cc9a196 · commit ee4e8cc

File tree
  src/MaxText/configs/base.yml
  src/MaxText/layers/moe.py

2 files changed (+12, -5 lines)

src/MaxText/configs/base.yml
Lines changed: 1 addition & 0 deletions

```diff
@@ -149,6 +149,7 @@ logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embed
 cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
 float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
 float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax
+float32_weight_sum: True # whether to use full fp32 precision for weight_sum during final unpermute in moe

 # Multi-Token Prediction Configs
 # The number of auxiliary prediction layers to use for MTP.
```
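The new float32_weight_sum flag (default True) controls whether the MoE weight sum, which combines the selected experts' outputs with their routing weights, is computed in full fp32 or left in the activation dtype. Below is a minimal, self-contained sketch of the numerical difference between the two paths; it is not MaxText code, and the shapes and random inputs are illustrative assumptions. The size of the gap depends on the backend's accumulation behavior.

```python
import jax
import jax.numpy as jnp

# Illustrative shapes (not from MaxText): 4 tokens, top-k of 8 experts, 128 features.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
expert_out = jax.random.normal(k1, (4, 8, 128), dtype=jnp.bfloat16)                    # [B, K, E]
weights = jax.nn.softmax(jax.random.normal(k2, (4, 8)), axis=-1).astype(jnp.bfloat16)  # [B, K]

# Weight sum in the activation dtype vs. the fp32 path selected by float32_weight_sum: True.
low = jnp.einsum("BKE,BK -> BE", expert_out, weights)
high = jnp.einsum("BKE,BK -> BE", expert_out.astype(jnp.float32), weights.astype(jnp.float32))

print(low.dtype, high.dtype)                          # bfloat16 float32
print(jnp.abs(high - low.astype(jnp.float32)).max())  # small rounding gap (backend-dependent)
```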

src/MaxText/layers/moe.py
Lines changed: 11 additions & 5 deletions

```diff
@@ -579,10 +579,13 @@ def unpermute(
     if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
       # For Llama4, combine using weights of 1 for selected experts
       reshaped_weights = jnp.ones_like(reshaped_weights)
+    if self.config.float32_weight_sum:
+      reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
+      reshaped_weights = reshaped_weights.astype(jnp.float32)
     output = jnp.einsum(
         "BKE,BK -> BE",
-        reshaped_intermediate.astype(jnp.float32),
-        reshaped_weights.astype(jnp.float32),
+        reshaped_intermediate,
+        reshaped_weights,
         precision=matmul_precision,
     )
     return output.reshape(batch_size, sequence_length, -1).astype(self.dtype)
```
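For readability, here is a rough, self-contained stand-in for the combine step this hunk produces in unpermute: the fp32 cast of both einsum operands is now gated on the config flag rather than applied unconditionally, and the result is still cast back to the layer dtype at the end. The function name, signature, and the out_dtype/matmul_precision arguments are assumptions for illustration, not the actual method.

```python
import jax.numpy as jnp

def combine_topk(reshaped_intermediate, reshaped_weights, *,
                 float32_weight_sum, out_dtype, matmul_precision=None):
  """Illustrative stand-in for the combine inside unpermute (names follow the diff)."""
  if float32_weight_sum:
    # New behavior: cast both operands to fp32 only when the config flag is enabled.
    reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
    reshaped_weights = reshaped_weights.astype(jnp.float32)
  output = jnp.einsum(
      "BKE,BK -> BE",  # [tokens, K experts, features] x [tokens, K] -> [tokens, features]
      reshaped_intermediate,
      reshaped_weights,
      precision=matmul_precision,
  )
  # The real method also reshapes back to [batch, seq, -1] before the dtype cast.
  return output.astype(out_dtype)
```

With the flag off, the contraction now runs in whatever dtype the operands arrive in, which is the only behavioral change relative to the previous unconditional .astype(jnp.float32) calls.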
```diff
@@ -1670,14 +1673,17 @@ def dense_matmul(
       if self.config.activations_in_float32:
         intermediate_layer = intermediate_layer.astype(jnp.float32)
       intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
-      with jax.named_scope("w_sum"):
+      with jax.named_scope("weight_sum"):
         if is_llama4_decoder_layer:
           weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
+        if self.config.float32_weight_sum:
+          intermediate_layer = intermediate_layer.astype(jnp.float32)
+          weights = weights.astype(jnp.float32)
         # cast to f32 for sum up in einsum op
         output = jnp.einsum(
             "BSEM,BSE -> BSM",
-            intermediate_layer.astype(jnp.float32),
-            weights.astype(jnp.float32),  # pylint: disable=undefined-variable,possibly-used-before-assignment
+            intermediate_layer,
+            weights,
             precision=matmul_precision,
         ).astype(self.dtype)
       return output, None
```
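The dense_matmul path gets the same treatment (plus a rename of the profiling scope from "w_sum" to "weight_sum"): the fp32 casts of intermediate_layer and weights are gated on config.float32_weight_sum, while the final .astype(self.dtype) keeps the output in the model dtype either way. A compressed sketch of the resulting weight sum, with a hypothetical helper name and argument list:

```python
import jax.numpy as jnp

def dense_weight_sum(intermediate_layer, weights, *,
                     float32_weight_sum, out_dtype, matmul_precision=None):
  """Illustrative stand-in for the dense-path weight sum (hypothetical helper, not MaxText API)."""
  if float32_weight_sum:
    intermediate_layer = intermediate_layer.astype(jnp.float32)
    weights = weights.astype(jnp.float32)
  # [B, S, experts, model] expert outputs x [B, S, experts] weights -> [B, S, model]
  return jnp.einsum("BSEM,BSE -> BSM", intermediate_layer, weights,
                    precision=matmul_precision).astype(out_dtype)
```

Since base.yml defaults float32_weight_sum to True, both code paths keep the previous always-fp32 numerics unless the flag is explicitly turned off.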
