Skip to content

Commit ce90e9f

Browse files
authored
Support vLLM IR on XPU (#148)
* Support vLLM IR on XPU
  Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
* test layernorm on xpu
  Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
---------
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
1 parent c5a04b8 commit ce90e9f

File tree

4 files changed

+63
-10
lines changed

4 files changed

+63
-10
lines changed

tests/kernels/ir/test_layernorm.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@ def rms_norm_inputs(n_tokens: int, hidden_size: int, dtype: torch.dtype):
2020

2121

2222
@pytest.mark.skipif(
23-
not current_platform.is_cuda_alike(),
24-
reason="Currently only kernels on CUDA and ROCm",
23+
not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
24+
reason="Currently only kernels on CUDA, ROCm and XPU",
2525
)
2626
def test_rms_norm_registration():
2727
expected = {
2828
"native": True,
29-
"vllm_c": True,
29+
"vllm_c": current_platform.is_cuda_alike(),
3030
"aiter": current_platform.is_rocm(),
3131
"oink": False,
32+
"xpu_kernels": current_platform.is_xpu(),
3233
}
3334

3435
actual = {
@@ -43,13 +44,13 @@ def test_rms_norm_registration():
4344
@pytest.mark.parametrize("hidden_size", [16, 4096, 8192])
4445
@pytest.mark.parametrize("epsilon", [1e-6, 1e-5])
4546
@pytest.mark.skipif(
46-
not current_platform.is_cuda_alike(),
47-
reason="Currently only kernels on CUDA and ROCm",
47+
not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
48+
reason="Currently only kernels on CUDA, ROCm and XPU",
4849
)
4950
class TestRMSNorm:
5051
@classmethod
5152
def setup_class(cls, **kwargs):
52-
torch.set_default_device("cuda")
53+
torch.set_default_device(current_platform.device_name)
5354

5455
def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon):
5556
x, weight = rms_norm_inputs(4, 8, dtype)
@@ -70,7 +71,7 @@ def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon):
7071
out4 = rms_norm_native(x, None, epsilon=epsilon)
7172
torch.testing.assert_close(out3, out4)
7273

73-
@pytest.mark.parametrize("provider", ["vllm_c", "aiter"])
74+
@pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels"])
7475
def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider):
7576
impl = ir.ops.rms_norm.impls[provider]
7677
if not impl.supported:
@@ -115,7 +116,7 @@ def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider):
115116
atol=2e-4,
116117
)
117118

118-
@pytest.mark.parametrize("provider", ["vllm_c", "aiter", "native"])
119+
@pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels", "native"])
119120
def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider):
120121
if not ir.ops.rms_norm.impls[provider].supported:
121122
pytest.skip(f"{provider} impl not supported on this platform")

vllm/kernels/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
"""Kernel implementations for vLLM."""
44

5-
from . import aiter_ops, oink_ops, vllm_c
5+
from . import aiter_ops, oink_ops, vllm_c, xpu_ops
66

7-
__all__ = ["vllm_c", "aiter_ops", "oink_ops"]
7+
__all__ = ["vllm_c", "aiter_ops", "oink_ops", "xpu_ops"]

vllm/kernels/xpu_ops.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import torch
4+
from torch import Tensor
5+
6+
from vllm import ir
7+
from vllm.platforms import current_platform
8+
9+
current_platform.import_kernels()
10+
11+
def is_xpu_kernels_found() -> bool:
    """Return True when the ``vllm_xpu_kernels`` package is importable."""
    import importlib.util

    return importlib.util.find_spec("vllm_xpu_kernels") is not None
XPU_KERNELS_SUPPORTED = is_xpu_kernels_found()
"""Kernels in this file are supported if vLLM XPU kernels are installed."""


def rms_no_var(x, weight, epsilon, variance_size=None):
    # Argument filter for register_impl: the XPU kernel has no
    # partial-variance support, so only variance_size=None is accepted.
    # (def instead of a named lambda, per PEP 8 / E731.)
    return variance_size is None


@ir.ops.rms_norm.register_impl(
    "xpu_kernels", supports_args=rms_no_var, supported=XPU_KERNELS_SUPPORTED
)
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """RMS-normalize ``x`` using the vLLM XPU kernel (``torch.ops._C.rms_norm``).

    Args:
        x: Input tensor; the kernel normalizes over the last dimension
           (weight is sized to ``x.shape[-1]``).
        weight: Optional scale tensor. The kernel requires a weight, so ones
            (an identity scale) are substituted when ``None``.
        epsilon: Numerical-stability term for the variance.
        variance_size: Not supported by this kernel; must be ``None``
            (filtered by ``rms_no_var`` at registration, re-checked here).

    Returns:
        A new tensor with the same shape, dtype, and device as ``x``.
    """
    # Validate before allocating anything; raise instead of assert so the
    # check survives python -O.
    if variance_size is not None:
        raise ValueError("xpu_kernels rms_norm does not support variance_size")
    if weight is None:
        # Kernel requires a weight tensor; pass ones (identity scale).
        weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
    output = torch.empty_like(x)
    torch.ops._C.rms_norm(output, x, weight, epsilon)
    return output

vllm/platforms/xpu.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
if TYPE_CHECKING:
2222
from vllm.config import VllmConfig
23+
from vllm.config.kernel import IrOpPriorityConfig
2324
from vllm.v1.attention.selector import AttentionSelectorConfig
2425
else:
2526
VllmConfig = None
@@ -273,6 +274,21 @@ def get_device_communicator_cls(cls) -> str:
273274
)
274275
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
275276

277+
@classmethod
278+
def get_default_ir_op_priority(
279+
cls, vllm_config: "VllmConfig"
280+
) -> "IrOpPriorityConfig":
281+
from vllm.config.compilation import CompilationMode
282+
from vllm.config.kernel import IrOpPriorityConfig
283+
284+
# Native used by default when compiling,
285+
# use fused kernels where available when no codegen
286+
cc = vllm_config.compilation_config
287+
using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
288+
default = ["native"] if using_inductor else ["xpu_kernels", "vllm_c", "native"]
289+
290+
return IrOpPriorityConfig.with_default(default)
291+
276292
@classmethod
277293
def device_count(cls) -> int:
278294
return torch.xpu.device_count()

0 commit comments

Comments
 (0)