@@ -580,10 +580,13 @@ def unpermute(
     if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
       # For Llama4, combine using weights of 1 for selected experts
       reshaped_weights = jnp.ones_like(reshaped_weights)
+    if self.config.float32_weight_sum:
+      reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
+      reshaped_weights = reshaped_weights.astype(jnp.float32)
     output = jnp.einsum(
         "BKE,BK -> BE",
-        reshaped_intermediate.astype(jnp.float32),
-        reshaped_weights.astype(jnp.float32),
+        reshaped_intermediate,
+        reshaped_weights,
         precision=matmul_precision,
     )
     return output.reshape(batch_size, sequence_length, -1).astype(self.dtype)
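For context, the pattern this first hunk introduces reduces to a small standalone sketch: the float32 cast of the top-k expert outputs and their routing weights now happens only when a `float32_weight_sum` config flag is set, instead of being hard-coded into the einsum arguments. Everything below except the flag name and the "BKE,BK -> BE" contraction is an assumption made for illustration (function name, shapes, dtypes, precision choice), not the MaxText API.

```python
import jax
import jax.numpy as jnp

def combine_top_k(reshaped_intermediate, reshaped_weights,
                  float32_weight_sum, out_dtype=jnp.bfloat16):
  """Hypothetical helper: weighted sum of top-k expert outputs.

  reshaped_intermediate: (tokens, top_k, emb_dim)
  reshaped_weights:      (tokens, top_k)
  """
  if float32_weight_sum:
    # Config-gated cast: accumulate the weighted sum in float32 only on request.
    reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
    reshaped_weights = reshaped_weights.astype(jnp.float32)
  output = jnp.einsum(
      "BKE,BK -> BE",
      reshaped_intermediate,
      reshaped_weights,
      precision=jax.lax.Precision.HIGHEST,
  )
  return output.astype(out_dtype)

x = jnp.ones((4, 2, 8), dtype=jnp.bfloat16)    # 4 tokens, top_k=2, emb_dim=8
w = jnp.full((4, 2), 0.5, dtype=jnp.bfloat16)  # routing weights per selected expert
print(combine_top_k(x, w, float32_weight_sum=True).shape)   # (4, 8)
print(combine_top_k(x, w, float32_weight_sum=False).dtype)  # bfloat16 end to end
```

With the flag off, the combine stays in the activation dtype throughout, trading some accumulation precision for lower memory traffic; with it on, the previous always-float32 behavior is preserved.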
@@ -1681,14 +1684,17 @@ def dense_matmul(
     if self.config.activations_in_float32:
       intermediate_layer = intermediate_layer.astype(jnp.float32)
     intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
-    with jax.named_scope("w_sum"):
+    with jax.named_scope("weight_sum"):
       if is_llama4_decoder_layer:
         weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
+      if self.config.float32_weight_sum:
+        intermediate_layer = intermediate_layer.astype(jnp.float32)
+        weights = weights.astype(jnp.float32)
       # cast to f32 for sum up in einsum op
       output = jnp.einsum(
           "BSEM,BSE -> BSM",
-          intermediate_layer.astype(jnp.float32),
-          weights.astype(jnp.float32),  # pylint: disable=undefined-variable,possibly-used-before-assignment
+          intermediate_layer,
+          weights,
           precision=matmul_precision,
       ).astype(self.dtype)
       return output, None
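The second hunk applies the same gating to the dense combine path, where the weight tensor covers every expert per token and unselected experts simply carry zero weight. A minimal sketch under the same assumptions (made-up shapes and function name; only the flag and the "BSEM,BSE -> BSM" contraction come from the diff):

```python
import jax
import jax.numpy as jnp

def combine_dense(intermediate_layer, weights, float32_weight_sum,
                  out_dtype=jnp.bfloat16):
  """Hypothetical helper: weighted sum over the full expert axis.

  intermediate_layer: (batch, seq, num_experts, model_dim)
  weights:            (batch, seq, num_experts), zero for unselected experts
  """
  if float32_weight_sum:
    # Same optional float32 accumulation as in the unpermute path above.
    intermediate_layer = intermediate_layer.astype(jnp.float32)
    weights = weights.astype(jnp.float32)
  return jnp.einsum(
      "BSEM,BSE -> BSM",
      intermediate_layer,
      weights,
      precision=jax.lax.Precision.HIGHEST,
  ).astype(out_dtype)

h = jnp.ones((2, 3, 4, 8), dtype=jnp.bfloat16)                      # batch=2, seq=3, experts=4, model=8
w = jnp.zeros((2, 3, 4), dtype=jnp.bfloat16).at[:, :, 0].set(1.0)   # route everything to expert 0
print(combine_dense(h, w, float32_weight_sum=True).shape)           # (2, 3, 8)
```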