@@ -579,10 +579,13 @@ def unpermute(
     if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
       # For Llama4, combine using weights of 1 for selected experts
       reshaped_weights = jnp.ones_like(reshaped_weights)
+    if self.config.float32_weight_sum:
+      reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
+      reshaped_weights = reshaped_weights.astype(jnp.float32)
     output = jnp.einsum(
         "BKE,BK -> BE",
-        reshaped_intermediate.astype(jnp.float32),
-        reshaped_weights.astype(jnp.float32),
+        reshaped_intermediate,
+        reshaped_weights,
         precision=matmul_precision,
     )
     return output.reshape(batch_size, sequence_length, -1).astype(self.dtype)
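
For context, a minimal standalone sketch of the pattern this hunk introduces (not MaxText's actual code): the upcast to float32 before the top-k combine einsum becomes conditional on a boolean flag, modeled here as a plain function argument. The function name combine_topk and the use of jax.lax.Precision.HIGHEST in place of the config-driven matmul_precision are illustrative assumptions.

import jax
import jax.numpy as jnp

def combine_topk(reshaped_intermediate, reshaped_weights,
                 float32_weight_sum=True, out_dtype=jnp.bfloat16):
  """Weighted sum over the K selected experts.

  reshaped_intermediate: [B, K, E], reshaped_weights: [B, K] -> [B, E].
  Hypothetical stand-in for the gated cast shown in the hunk above.
  """
  if float32_weight_sum:
    # Upcast both operands so the combine accumulates in float32
    # (the previous, unconditional behavior).
    reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
    reshaped_weights = reshaped_weights.astype(jnp.float32)
  output = jnp.einsum(
      "BKE,BK -> BE",
      reshaped_intermediate,
      reshaped_weights,
      precision=jax.lax.Precision.HIGHEST,  # stand-in for matmul_precision
  )
  return output.astype(out_dtype)
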
@@ -1670,14 +1673,17 @@ def dense_matmul(
     if self.config.activations_in_float32:
       intermediate_layer = intermediate_layer.astype(jnp.float32)
     intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
-    with jax.named_scope("w_sum"):
+    with jax.named_scope("weight_sum"):
       if is_llama4_decoder_layer:
         weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
+      if self.config.float32_weight_sum:
+        intermediate_layer = intermediate_layer.astype(jnp.float32)
+        weights = weights.astype(jnp.float32)
       # cast to f32 for sum up in einsum op
       output = jnp.einsum(
           "BSEM,BSE -> BSM",
-          intermediate_layer.astype(jnp.float32),
-          weights.astype(jnp.float32),  # pylint: disable=undefined-variable,possibly-used-before-assignment
+          intermediate_layer,
+          weights,
           precision=matmul_precision,
       ).astype(self.dtype)
     return output, None
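
As a sanity check on the intent of the flag, a hypothetical parity test building on the combine_topk sketch above: with float32_weight_sum=True the gated path should reproduce the previous unconditional-upcast behavior exactly, since the same casts precede the same einsum.

inter = jax.random.normal(jax.random.PRNGKey(0), (8, 2, 16)).astype(jnp.bfloat16)
weights = jax.random.uniform(jax.random.PRNGKey(1), (8, 2)).astype(jnp.bfloat16)

# Old behavior: always upcast both operands before the combine.
old = jnp.einsum(
    "BKE,BK -> BE",
    inter.astype(jnp.float32),
    weights.astype(jnp.float32),
    precision=jax.lax.Precision.HIGHEST,
).astype(jnp.bfloat16)

# New behavior with the flag enabled should match bit-for-bit.
new = combine_topk(inter, weights, float32_weight_sum=True, out_dtype=jnp.bfloat16)
assert jnp.array_equal(old, new)

# With the flag disabled, the operands reach the einsum in bfloat16; whether
# the backend still accumulates in a wider type depends on the platform and
# precision setting.
low_precision = combine_topk(inter, weights, float32_weight_sum=False, out_dtype=jnp.bfloat16)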