Skip to content

Commit a600e9f

Browse files
authored
FP8 FA fixes (ROCm#381)
* FP8 FA fixes — Summary: add the missing clamp and fix the reciprocal-scale computation.
* linter fixes
1 parent b5839a1 commit a600e9f

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -681,9 +681,8 @@ def forward(
681681
seq_lens,
682682
make_attn_mask=False) # type: ignore
683683
full_scales = (
684-
1.0 / layer._q_scale.item(),
685-
1.0 / layer._k_scale.item(), 1.0 /
686-
layer._v_scale.item(), 1.0 / layer._prob_scale.item(),
684+
layer._q_scale.item(), layer._k_scale.item(),
685+
layer._v_scale.item(), layer._prob_scale.item(),
687686
fp8_out_scale.item()) if (
688687
fp8_out_scale and layer._q_scale
689688
and layer._prob_scale

vllm/attention/ops/triton_flash_attention.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ def get_autotune_configs():
390390

391391
autotune_configs, autotune_keys = get_autotune_configs()
392392

393+
float8_info = torch.finfo(torch.float8_e4m3fnuz)
394+
393395

394396
@triton.autotune(
395397
configs=autotune_configs,
@@ -451,6 +453,8 @@ def attn_fwd(
451453
BIAS_TYPE: tl.constexpr,
452454
ENABLE_DROPOUT: tl.constexpr,
453455
RETURN_ENCODED_SOFTMAX: tl.constexpr,
456+
FP8_MIN: tl.constexpr = float8_info.min,
457+
FP8_MAX: tl.constexpr = float8_info.max,
454458
):
455459
start_m = tl.program_id(0)
456460
off_h_q = tl.program_id(1)
@@ -733,6 +737,7 @@ def attn_fwd(
733737
causal_start_idx = seqlen_q - seqlen_k
734738
if USE_FP8:
735739
acc *= o_descale
740+
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
736741
acc = acc.to(Out.type.element_ty)
737742
if IS_CAUSAL: # noqa: SIM102
738743
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
@@ -832,9 +837,9 @@ def forward(
832837

833838
def check_and_convert(t, scale):
834839
if t.dtype != float8:
835-
finfo = torch.finfo(float8)
836840
descale = 1.0 / scale
837-
ts = (t * descale).clamp(min=finfo.min, max=finfo.max)
841+
ts = (t * descale).clamp(min=float8_info.min,
842+
max=float8_info.max)
838843
return ts.to(float8)
839844
else:
840845
return t

0 commit comments

Comments (0)