Skip to content

Commit fb9e280

Browse files
committed
register xpu custom op as torch.ops.vllm_xpu.xxx
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
1 parent ce90e9f commit fb9e280

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

vllm/kernels/xpu_ops.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import functools
4+
35
import torch
46
from torch import Tensor
7+
from torch.library import Library
58

69
from vllm import ir
710
from vllm.platforms import current_platform
11+
from vllm.utils.torch_utils import direct_register_custom_op
812

913
current_platform.import_kernels()
1014

@@ -18,6 +22,15 @@ def is_xpu_kernels_found() -> bool:
1822
XPU_KERNELS_SUPPORTED = is_xpu_kernels_found()
1923
"""Kernels in this file are supported if vLLM XPU kernels are installed."""
2024

# Dedicated torch library fragment holding the vLLM XPU custom ops.
# Ops registered here become callable as ``torch.ops.vllm_xpu.<op_name>``.
xpu_kernels_lib = Library("vllm_xpu", "FRAGMENT")

# Convenience registrar: ``direct_register_custom_op`` pre-bound to the XPU
# library above, so each call site only has to describe the op itself.
direct_register_xpu_op = functools.partial(
    direct_register_custom_op,
    target_lib=xpu_kernels_lib,
)
2134
def rms_no_var(x, weight, epsilon, variance_size=None):
    """Return ``True`` when no per-slice variance size was requested.

    Only ``variance_size`` is inspected; ``x``, ``weight`` and ``epsilon``
    are accepted purely so the signature lines up with the rms-norm op this
    predicate is paired with (presumably an op-dispatch guard — confirm at
    the call site, which is not visible in this chunk).
    """
    # PEP 8 (E731): a named callable should be a ``def``, not a lambda
    # assignment — same behavior, but with a real name and docstring.
    return variance_size is None
2235

2336

@@ -31,6 +44,19 @@ def rms_norm(
3144
# Kernel requires weight tensor, pass ones
3245
weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
3346
assert variance_size is None
47+
return torch.ops.vllm_xpu.rms_norm(x, weight, epsilon)
48+
49+
50+
def _rms_norm_impl(x: Tensor, weight: Tensor, epsilon: float) -> Tensor:
    """Eager implementation backing ``torch.ops.vllm_xpu.rms_norm``.

    Allocates the result tensor and delegates the normalization itself to
    the compiled ``torch.ops._C.rms_norm`` kernel, which fills it in place.
    """
    result = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    torch.ops._C.rms_norm(result, x, weight, epsilon)
    return result
54+
55+
56+
def _rms_norm_fake(x: Tensor, weight: Tensor, epsilon: float) -> Tensor:
57+
return torch.empty_like(x)
58+
59+
60+
# Publish the eager implementation (with its fake for tracing/compile)
# under ``torch.ops.vllm_xpu.rms_norm``.
direct_register_xpu_op(
    op_name="rms_norm",
    op_func=_rms_norm_impl,
    fake_impl=_rms_norm_fake,
)

vllm/platforms/xpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def get_default_ir_op_priority(
285285
# use fused kernels where available when no codegen
286286
cc = vllm_config.compilation_config
287287
using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
288-
default = ["native"] if using_inductor else ["xpu_kernels", "vllm_c", "native"]
288+
default = ["native"] if using_inductor else ["xpu_kernels", "native"]
289289

290290
return IrOpPriorityConfig.with_default(default)
291291

0 commit comments

Comments
 (0)