Skip to content

Commit d983769

Browse files
authored
fix cuda graph (vllm-project#22721)
Signed-off-by: fsx950223 <fsx950223@outlook.com>
1 parent 8fd9209 commit d983769

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
"""Attention layer with AiterFlashAttention."""
44
from dataclasses import dataclass
5-
from typing import ClassVar, Optional
5+
from typing import Optional
66

77
import torch
88

@@ -11,7 +11,8 @@
1111
from vllm.config import VllmConfig
1212
from vllm.logger import init_logger
1313
from vllm.platforms import current_platform
14-
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
14+
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
15+
AttentionMetadataBuilder,
1516
CommonAttentionMetadata)
1617
from vllm.v1.kv_cache_interface import AttentionSpec
1718

@@ -231,7 +232,7 @@ class AiterFlashAttentionMetadata:
231232

232233
class AiterFlashAttentionMetadataBuilder(
233234
AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
234-
full_cudagraph_supported: ClassVar[bool] = True
235+
cudagraph_support = AttentionCGSupport.ALWAYS
235236

236237
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
237238
vllm_config: VllmConfig, device: torch.device):

0 commit comments

Comments (0)