@@ -390,6 +390,8 @@ def get_autotune_configs():
390390
391391autotune_configs , autotune_keys = get_autotune_configs ()
392392
393+ float8_info = torch .finfo (torch .float8_e4m3fnuz )
394+
393395
394396@triton .autotune (
395397 configs = autotune_configs ,
@@ -451,6 +453,8 @@ def attn_fwd(
451453 BIAS_TYPE : tl .constexpr ,
452454 ENABLE_DROPOUT : tl .constexpr ,
453455 RETURN_ENCODED_SOFTMAX : tl .constexpr ,
456+ FP8_MIN : tl .constexpr = float8_info .min ,
457+ FP8_MAX : tl .constexpr = float8_info .max ,
454458):
455459 start_m = tl .program_id (0 )
456460 off_h_q = tl .program_id (1 )
@@ -733,6 +737,7 @@ def attn_fwd(
733737 causal_start_idx = seqlen_q - seqlen_k
734738 if USE_FP8 :
735739 acc *= o_descale
740+ acc = tl .clamp (acc , FP8_MIN , FP8_MAX )
736741 acc = acc .to (Out .type .element_ty )
737742 if IS_CAUSAL : # noqa: SIM102
738743 if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx :
@@ -832,9 +837,9 @@ def forward(
832837
833838 def check_and_convert (t , scale ):
834839 if t .dtype != float8 :
835- finfo = torch .finfo (float8 )
836840 descale = 1.0 / scale
837- ts = (t * descale ).clamp (min = finfo .min , max = finfo .max )
841+ ts = (t * descale ).clamp (min = float8_info .min ,
842+ max = float8_info .max )
838843 return ts .to (float8 )
839844 else :
840845 return t
0 commit comments