Commit bbbd696

[torch.compile][CI] Add back attn fusion on hopper/ada (vllm-project#32940)
Signed-off-by: Luka Govedič <[email protected]>
Parent: 9b77bb7


tests/compile/test_fusion_attn.py

Lines changed: 4 additions & 5 deletions
@@ -18,7 +18,7 @@
     is_blackwell,
     run_model,
 )
-from tests.utils import cuda_device_count_stateless, flat_product
+from tests.utils import flat_product
 from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention.layer import Attention
@@ -265,13 +265,13 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 HEADS = [(64, 8), (40, 8)]
 PATTERN_TEST_MODELS_FP8 = [
     (
-        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
+        "RedHatAI/Meta-Llama-3.1-8B-FP8",
         TestAttentionFp8StaticQuantPatternModel,
     )
 ]
 PATTERN_TEST_MODELS_FP4 = [
     (
-        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
+        "nvidia/Llama-3.1-8B-Instruct-NVFP4",
        TestAttentionNvfp4QuantPatternModel,
     )
 ]
@@ -331,9 +331,8 @@ def test_attention_quant_pattern(
     if backend == AttentionBackendEnum.FLASHINFER and (
         not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
     ):
+        # This also captures the FP4 case
         pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
-    if "Llama-4-Scout" in model_name and cuda_device_count_stateless() < 2:
-        pytest.skip("Llama-4-Scout requires at least 2 GPUs")
 
     custom_ops_list = custom_ops.split(",") if custom_ops else []
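
For a local sanity check on a Hopper or Ada machine, the re-enabled test can be invoked directly with pytest. The snippet below is only an illustrative sketch, not part of the commit: the -k filter and the single-GPU setup are my assumptions, and the FlashInfer-backed parametrizations will still self-skip on anything below Blackwell, as the diff above shows.

# Illustrative sketch: run the re-enabled attention-fusion pattern test locally.
# Assumes a CUDA build of vLLM and one Hopper (SM 9.0) or Ada (SM 8.9) GPU.
import pytest
import torch

if __name__ == "__main__":
    major, minor = torch.cuda.get_device_capability()
    print(f"Detected compute capability {major}.{minor}")
    # FlashInfer cases skip themselves below SM 10.0; the remaining backends run.
    raise SystemExit(
        pytest.main(
            [
                "tests/compile/test_fusion_attn.py",
                "-k", "test_attention_quant_pattern",
                "-q",
            ]
        )
    )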
