diff --git a/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu b/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu index 66ef1b3..8b06789 100644 --- a/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu +++ b/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu @@ -374,7 +374,7 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p scale_for_old = 1.0f; new_max = mi; } else { - new_max = cur_max; + new_max = max(mi, cur_max); scale_for_old = exp2f(mi - new_max); }