
Commit 3d215cd

Merge pull request #2519 from AI-Hypercomputer:qinwen/add_up_quantize_config
PiperOrigin-RevId: 827554776
2 parents d2f608d + ee4e8cc commit 3d215cd

2 files changed: +12 additions, −5 deletions

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -149,6 +149,7 @@ logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embed
 cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
 float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
 float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax
+float32_weight_sum: True # whether to use full fp32 precision for weight_sum during final unpermute in moe

 # Multi-Token Prediction Configs
 # The number of auxiliary prediction layers to use for MTP.
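
The new float32_weight_sum key defaults to True, which preserves the previous behavior of promoting the MoE combine to fp32; setting it to False lets the weighted sum run in the incoming activation dtype (typically bf16). Below is a minimal sketch of the intended semantics; the helper name combine_expert_outputs and the shapes are made up for illustration and are not part of the PR.

import jax
import jax.numpy as jnp

def combine_expert_outputs(intermediate, weights, float32_weight_sum=True):
  # Mirrors the gated cast added in moe.py: promote to fp32 only when the
  # config flag is set; otherwise the einsum runs in the incoming dtype.
  if float32_weight_sum:
    intermediate = intermediate.astype(jnp.float32)
    weights = weights.astype(jnp.float32)
  return jnp.einsum("BKE,BK -> BE", intermediate, weights)

key = jax.random.PRNGKey(0)
intermediate = jax.random.normal(key, (2, 4, 8), dtype=jnp.bfloat16)  # (tokens, top_k, hidden)
weights = jax.nn.softmax(jax.random.normal(key, (2, 4)), axis=-1).astype(jnp.bfloat16)

print(combine_expert_outputs(intermediate, weights).dtype)         # float32
print(combine_expert_outputs(intermediate, weights, False).dtype)  # bfloat16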

src/MaxText/layers/moe.py

Lines changed: 11 additions & 5 deletions
@@ -580,10 +580,13 @@ def unpermute(
     if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
       # For Llama4, combine using weights of 1 for selected experts
       reshaped_weights = jnp.ones_like(reshaped_weights)
+    if self.config.float32_weight_sum:
+      reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
+      reshaped_weights = reshaped_weights.astype(jnp.float32)
     output = jnp.einsum(
         "BKE,BK -> BE",
-        reshaped_intermediate.astype(jnp.float32),
-        reshaped_weights.astype(jnp.float32),
+        reshaped_intermediate,
+        reshaped_weights,
         precision=matmul_precision,
     )
     return output.reshape(batch_size, sequence_length, -1).astype(self.dtype)
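
This hunk moves the fp32 casts out of the einsum arguments and gates them on the new flag, so the default path is numerically identical to the old unconditional cast. The snippet below is a synthetic illustration, independent of the MaxText layer, of the precision trade-off the flag controls when the cast is skipped.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# fp32 "ground truth" for the top-k combine; shapes are arbitrary.
reshaped_intermediate = jax.random.normal(key, (4, 8, 1024))
reshaped_weights = jax.nn.softmax(jax.random.normal(jax.random.PRNGKey(1), (4, 8)), axis=-1)

ref = jnp.einsum("BKE,BK -> BE", reshaped_intermediate, reshaped_weights)

# Flag off: inputs stay in bf16, so the combine runs at reduced precision.
bf16_path = jnp.einsum(
    "BKE,BK -> BE",
    reshaped_intermediate.astype(jnp.bfloat16),
    reshaped_weights.astype(jnp.bfloat16),
).astype(jnp.float32)

# Flag on: the same bf16 inputs are promoted to fp32 before the combine.
fp32_path = jnp.einsum(
    "BKE,BK -> BE",
    reshaped_intermediate.astype(jnp.bfloat16).astype(jnp.float32),
    reshaped_weights.astype(jnp.bfloat16).astype(jnp.float32),
)

# Exact magnitudes depend on backend and shapes, but the fp32 path generally
# tracks the fp32 reference at least as closely as the bf16 path does.
print(jnp.max(jnp.abs(bf16_path - ref)), jnp.max(jnp.abs(fp32_path - ref)))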
@@ -1681,14 +1684,17 @@ def dense_matmul(
      if self.config.activations_in_float32:
        intermediate_layer = intermediate_layer.astype(jnp.float32)
      intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
-      with jax.named_scope("w_sum"):
+      with jax.named_scope("weight_sum"):
        if is_llama4_decoder_layer:
          weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
+        if self.config.float32_weight_sum:
+          intermediate_layer = intermediate_layer.astype(jnp.float32)
+          weights = weights.astype(jnp.float32)
        # cast to f32 for sum up in einsum op
        output = jnp.einsum(
            "BSEM,BSE -> BSM",
-            intermediate_layer.astype(jnp.float32),
-            weights.astype(jnp.float32),  # pylint: disable=undefined-variable,possibly-used-before-assignment
+            intermediate_layer,
+            weights,
            precision=matmul_precision,
        ).astype(self.dtype)
      return output, None
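
The dense_matmul path gets the same gated cast, and the profiler scope is renamed from w_sum to weight_sum. Note that the trailing .astype(self.dtype) is unchanged, so the flag only affects the precision of the combine itself; the layer still returns activations in the configured dtype. A small standalone check of that behavior follows (synthetic shapes, not the MaxText layer).

import jax.numpy as jnp

# (batch, seq, experts, hidden) and (batch, seq, experts), matching the
# "BSEM,BSE -> BSM" einsum above; values are arbitrary.
intermediate_layer = jnp.ones((2, 3, 4, 8), dtype=jnp.bfloat16)
weights = jnp.full((2, 3, 4), 0.25, dtype=jnp.bfloat16)
model_dtype = jnp.bfloat16

# Even when the combine itself runs in fp32, the result is cast back to the
# model dtype, mirroring the .astype(self.dtype) after the einsum.
output = jnp.einsum(
    "BSEM,BSE -> BSM",
    intermediate_layer.astype(jnp.float32),
    weights.astype(jnp.float32),
).astype(model_dtype)
print(output.dtype)  # bfloat16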
