
Commit 16ded21

[XPU] support Triton Attention backend on Intel GPU (vllm-project#24149)
Signed-off-by: Kunshang Ji <[email protected]>
1 parent 2b30afa commit 16ded21

File tree

5 files changed: +49 -15 lines

.buildkite/scripts/hardware_ci/run-xpu-test.sh

Lines changed: 5 additions & 4 deletions
@@ -30,10 +30,11 @@ docker run \
   bash -c '
     set -e
     echo $ZE_AFFINITY_MASK
-    VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
-    VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
-    VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
-    VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
+    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
+    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
+    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
+    VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
     cd tests
     pytest -v -s v1/core
     pytest -v -s v1/engine
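For orientation, the new CI line above is roughly equivalent to the following offline-inference sketch, with the backend picked through the VLLM_ATTENTION_BACKEND environment variable; the prompt and sampling settings below are illustrative and not taken from generate.py.

import os

# Select the Triton attention backend before vLLM resolves one (sketch).
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

from vllm import LLM, SamplingParams

# block_size and enforce_eager mirror the flags in the CI command above.
llm = LLM(model="facebook/opt-125m", block_size=64, enforce_eager=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)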

vllm/_ipex_ops.py

Lines changed: 2 additions & 3 deletions
@@ -242,10 +242,9 @@ def reshape_and_cache_flash(
         k_scale_float: float = 1.0,
         v_scale_float: float = 1.0,
     ) -> None:
-        assert kv_cache_dtype == "auto"
-        # TODO: support FP8 kv cache.
         ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
-            key, value, key_cache, value_cache, slot_mapping)
+            key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
+            k_scale_float, v_scale_float)

     @staticmethod
     def flash_attn_varlen_func(

vllm/attention/ops/paged_attn.py

Lines changed: 6 additions & 1 deletion
@@ -6,9 +6,14 @@

 import torch

-from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON

+if current_platform.is_cuda_alike():
+    from vllm import _custom_ops as ops
+elif current_platform.is_xpu():
+    from vllm._ipex_ops import ipex_ops as ops
+
 if HAS_TRITON:
     from vllm.attention.ops.prefix_prefill import context_attention_fwd

vllm/platforms/xpu.py

Lines changed: 26 additions & 2 deletions
@@ -37,14 +37,38 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool) -> str:
-        if selected_backend is not None and selected_backend != _Backend.IPEX:
-            logger.info("Cannot use %s backend on XPU.", selected_backend)
         use_v1 = envs.VLLM_USE_V1
         if not use_v1:
             raise ValueError("XPU backend only supports V1.")
+        TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+            logger.info_once("Using Triton backend on V1 engine.")
+            return TRITON_ATTN_VLLM_V1
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info_once("Using Flash Attention backend on V1 engine.")
+            return FLASH_ATTN_V1
+        elif selected_backend:
+            raise ValueError(
+                f"Invalid attention backend for {cls.device_name}, "
+                f"with use_v1: {use_v1} use_mla: {use_mla}")
+
         logger.info("Using Flash Attention backend on V1 engine.")
         return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"

+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
+        """
+        Check if the kv_cache_dtype is supported.
+        XPU only support fp8 kv cache with triton backend.
+        """
+        if envs.is_set("VLLM_ATTENTION_BACKEND") and \
+                envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1":
+            return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
+
+        return False
+
     @classmethod
     def set_device(cls, device: torch.device) -> None:
         """

vllm/v1/attention/backends/triton_attn.py

Lines changed: 10 additions & 5 deletions
@@ -7,7 +7,6 @@

 import torch

-from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
@@ -23,6 +22,11 @@
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec

+if current_platform.is_cuda_alike():
+    from vllm import _custom_ops as ops
+elif current_platform.is_xpu():
+    from vllm._ipex_ops import ipex_ops as ops
+
 logger = init_logger(__name__)


@@ -337,7 +341,7 @@ def forward(
                     layer._v_scale,
                 )
             else:
-                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                ops.reshape_and_cache_flash(
                     key,
                     value,
                     key_cache,
@@ -354,9 +358,10 @@ def forward(
             num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale == 1.0, \
                 "A non 1.0 q_scale is not currently supported."
-            if not current_platform.is_rocm():
-                # Skip Q quantization on ROCm, since dequantizing back to
-                # f32 in the attention kernel is not supported.
+            if current_platform.is_cuda():
+                # Skip Q quantization on ROCm and XPU, enable this on cuda
+                # only, since dequantizing back to f32 in the attention kernel
+                # is not supported.
                 query, _ = ops.scaled_fp8_quant(
                     query.reshape(
                         (num_tokens, num_heads * head_size)).contiguous(),
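A small hedged sketch to see which of the platform gates above apply on a given machine; the printed summaries are illustrative paraphrases of the branches in the diff.

from vllm.platforms import current_platform

# Mirror the gates in the diff: is_cuda_alike()/is_xpu() pick the ops module,
# and is_cuda() alone enables FP8 quantization of the query.
if current_platform.is_cuda():
    print("CUDA: _custom_ops + FP8 Q quantization before the Triton kernel")
elif current_platform.is_cuda_alike():
    print("ROCm: _custom_ops, Q quantization skipped")
elif current_platform.is_xpu():
    print("XPU: ipex_ops for reshape_and_cache_flash, Q quantization skipped")
else:
    print("Other platforms: this diff does not bind an ops module")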
