@@ -37,14 +37,38 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool) -> str:
-        if selected_backend is not None and selected_backend != _Backend.IPEX:
-            logger.info("Cannot use %s backend on XPU.", selected_backend)
         use_v1 = envs.VLLM_USE_V1
         if not use_v1:
             raise ValueError("XPU backend only supports V1.")
+        TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+            logger.info_once("Using Triton backend on V1 engine.")
+            return TRITON_ATTN_VLLM_V1
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info_once("Using Flash Attention backend on V1 engine.")
+            return FLASH_ATTN_V1
+        elif selected_backend:
+            raise ValueError(
+                f"Invalid attention backend for {cls.device_name}, "
+                f"with use_v1: {use_v1} use_mla: {use_mla}")
+
         logger.info("Using Flash Attention backend on V1 engine.")
         return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
+        """
+        Check if the kv_cache_dtype is supported.
+        XPU only supports fp8 kv cache with the Triton backend.
+        """
+        if envs.is_set("VLLM_ATTENTION_BACKEND") and \
+                envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1":
+            return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
+
+        return False
+
     @classmethod
     def set_device(cls, device: torch.device) -> None:
         """
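The new `is_kv_cache_dtype_supported` hook only reports fp8 kv-cache support when the Triton V1 attention backend has been explicitly selected via `VLLM_ATTENTION_BACKEND`. The following is a minimal standalone sketch of that gate, not part of the diff: the helper name is hypothetical and it reads the environment variable directly rather than going through `vllm.envs`.

```python
import os


def xpu_supports_kv_cache_dtype(kv_cache_dtype: str) -> bool:
    # Mirrors the check added in the diff: fp8 kv cache is only reported as
    # supported on XPU when the Triton V1 attention backend is selected.
    if os.environ.get("VLLM_ATTENTION_BACKEND") == "TRITON_ATTN_VLLM_V1":
        return kv_cache_dtype in ("fp8_e4m3", "fp8_e5m2", "fp8")
    return False


if __name__ == "__main__":
    os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"
    print(xpu_supports_kv_cache_dtype("fp8_e4m3"))  # True
    print(xpu_supports_kv_cache_dtype("auto"))      # False
```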