
Commit 043c93d

Trying to make scales work with compileable attention
1 parent 079750e commit 043c93d

12 files changed, +50 -42 lines changed

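Every backend in this commit gets the same interface change: the packed fp8_comp_scales list is dropped from forward() and the query, softmax-probability, and fp8 output scales become individual Optional[torch.Tensor] keyword arguments. That flatter shape is presumably what "compileable attention" in the commit message refers to: a fixed set of optional tensor arguments is generally easier to express in a traced or custom-op signature than a Python list of optional tensors. Below is a minimal, self-contained sketch of the old versus new calling convention; forward_old and forward_new are illustrative names only, the real signatures take more parameters (query, key, value, kv_cache, attn_metadata, ...) than shown, and the bodies are placeholders.

from typing import List, Optional

import torch


def forward_old(
    query: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    fp8_comp_scales: Optional[List[Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
    # Old convention: the extra fp8 scales travel as one packed list and each
    # backend unpacks it itself (see the rocm_flash_attn.py hunks below).
    q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or [None, None, None]
    return output if output is not None else query


def forward_new(
    query: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    q_scale: Optional[torch.Tensor] = None,
    prob_scale: Optional[torch.Tensor] = None,
    fp8_out_scale: Optional[torch.Tensor] = None,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # New convention: each scale is its own Optional[torch.Tensor] argument.
    return output if output is not None else query


# Illustrative call: the existing k_scale/v_scale parameters are untouched and
# the new scales default to None, so callers that pass no fp8 scales still work.
scale = torch.ones(())
x = torch.randn(2, 8)
forward_new(x, scale, scale, q_scale=scale, fp8_out_scale=scale)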

vllm/attention/backends/abstract.py

Lines changed: 3 additions & 1 deletion
@@ -251,7 +251,9 @@ def forward(
         attn_metadata: T,
         k_scale: torch.Tensor,
         v_scale: torch.Tensor,
+        q_scale: Optional[torch.Tensor] = None,
+        prob_scale: Optional[torch.Tensor] = None,
+        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
-        fp8_comp_scales: List[Optional[torch.Tensor]] = None,
     ) -> torch.Tensor:
         raise NotImplementedError

vllm/attention/backends/blocksparse_attn.py

Lines changed: 3 additions & 1 deletion
@@ -368,8 +368,10 @@ def forward(
         attn_metadata: BlocksparseFlashAttentionMetadata,
         k_scale: torch.Tensor,
         v_scale: torch.Tensor,
+        q_scale: Optional[torch.Tensor] = None,
+        prob_scale: Optional[torch.Tensor] = None,
+        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
-        fp8_comp_scales: List[Optional[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.

vllm/attention/backends/flash_attn.py

Lines changed: 3 additions & 1 deletion
@@ -642,8 +642,10 @@ def forward(
         attn_metadata: FlashAttentionMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
+        q_scale: Optional[torch.Tensor] = None,
+        prob_scale: Optional[torch.Tensor] = None,
+        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
-        fp8_comp_scales: List[Optional[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

vllm/attention/backends/flashinfer.py

Lines changed: 3 additions & 1 deletion
@@ -782,8 +782,10 @@ def forward(
         attn_metadata: FlashInferMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
+        q_scale: Optional[torch.Tensor] = None,
+        prob_scale: Optional[torch.Tensor] = None,
+        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
-        fp8_comp_scales: List[Optional[torch.Tensor]] = None,
     ) -> torch.Tensor:

         # TODO: directly write to output tensor

vllm/attention/backends/hpu_attn.py

Lines changed: 3 additions & 1 deletion
@@ -159,8 +159,10 @@ def forward(
         attn_metadata: HPUAttentionMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
+        q_scale: Optional[torch.Tensor] = None,
+        prob_scale: Optional[torch.Tensor] = None,
+        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
-        fp8_comp_scales: List[Optional[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.

vllm/attention/backends/ipex_attn.py

Lines changed: 3 additions & 1 deletion
@@ -178,8 +178,10 @@ def forward(
         attn_metadata: IpexAttnMetadata,  # type: ignore
         k_scale: float = 1.0,
         v_scale: float = 1.0,
+        q_scale: Optional[torch.Tensor] = None,
+        prob_scale: Optional[torch.Tensor] = None,
+        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
-        fp8_comp_scales: List[Optional[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.

vllm/attention/backends/pallas.py

Lines changed: 3 additions & 1 deletion
@@ -157,8 +157,10 @@ def forward(
         attn_metadata: PallasMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
+        q_scale: Optional[torch.Tensor] = None,
+        prob_scale: Optional[torch.Tensor] = None,
+        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
-        fp8_comp_scales: List[Optional[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 4 additions & 5 deletions
@@ -550,8 +550,10 @@ def forward(
         attn_metadata: ROCmFlashAttentionMetadata,
         k_scale: torch.Tensor,
         v_scale: torch.Tensor,
+        q_scale: Optional[torch.Tensor] = None,
+        prob_scale: Optional[torch.Tensor] = None,
+        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
-        fp8_comp_scales: List[Optional[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.

@@ -601,9 +603,6 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or [None, None,
-                                                                 None]
-
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -687,7 +686,7 @@ def forward(
                     1.0 / q_scale.item(), 1.0 / k_scale.item(),
                     1.0 / v_scale.item(), 1.0 / prob_scale.item(),
                     fp8_out_scale.item()) if (
-                        fp8_out_scale
+                        fp8_out_scale and q_scale and prob_scale
                        and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None
                out, _ = self.attn_func(
                    query,
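Beyond the signature, rocm_flash_attn.py is the one backend with in-body changes: the fp8_comp_scales unpacking is deleted, and the guard that builds the full set of fp8 scales now has to check q_scale and prob_scale explicitly, since each can be None on its own. Below is a small sketch of that guard in isolation, assuming scalar scale tensors; make_full_scales and the module-level flag are stand-ins (not vLLM code), and where the hunk above uses plain truthiness this spells out "is not None".

from typing import Optional, Tuple

import torch

# Stand-in for envs.VLLM_USE_ROCM_FP8_FLASH_ATTN in the hunk above.
USE_ROCM_FP8_FLASH_ATTN = True


def make_full_scales(
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    q_scale: Optional[torch.Tensor],
    prob_scale: Optional[torch.Tensor],
    fp8_out_scale: Optional[torch.Tensor],
) -> Optional[Tuple[float, float, float, float, float]]:
    # Only build the tuple when every optional scale is present; otherwise the
    # .item() calls below would fail on None.
    if (fp8_out_scale is not None and q_scale is not None
            and prob_scale is not None and USE_ROCM_FP8_FLASH_ATTN):
        return (1.0 / q_scale.item(), 1.0 / k_scale.item(),
                1.0 / v_scale.item(), 1.0 / prob_scale.item(),
                fp8_out_scale.item())
    return None


one = torch.ones(())
print(make_full_scales(one, one, one, one, one))   # tuple of five floats
print(make_full_scales(one, one, None, one, one))  # None: q_scale not provided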

vllm/attention/backends/torch_sdpa.py

Lines changed: 3 additions & 1 deletion
@@ -437,8 +437,10 @@ def forward(
         attn_metadata: TorchSDPAMetadata,  # type: ignore
         k_scale: float = 1.0,
         v_scale: float = 1.0,
+        q_scale: Optional[torch.Tensor] = None,
+        prob_scale: Optional[torch.Tensor] = None,
+        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
-        fp8_comp_scales: List[Optional[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.

vllm/attention/backends/xformers.py

Lines changed: 3 additions & 1 deletion
@@ -421,8 +421,10 @@ def forward(
         attn_metadata: "XFormersMetadata",
         k_scale: float = 1.0,
         v_scale: float = 1.0,
+        q_scale: Optional[torch.Tensor] = None,
+        prob_scale: Optional[torch.Tensor] = None,
+        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
-        fp8_comp_scales: List[Optional[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
