11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+ import functools
4+
35import torch
46from torch import Tensor
7+ from torch .library import Library
58
69from vllm import ir
710from vllm .platforms import current_platform
11+ from vllm .utils .torch_utils import direct_register_custom_op
812
913current_platform .import_kernels ()
1014
@@ -18,6 +22,15 @@ def is_xpu_kernels_found() -> bool:
1822XPU_KERNELS_SUPPORTED = is_xpu_kernels_found ()
1923"""Kernels in this file are supported if vLLM XPU kernels are installed."""
2024
25+ xpu_kernels_lib = Library ("vllm_xpu" , "FRAGMENT" )
26+ """
27+ This library holds torch ops for vLLM XPU kernels.
28+ """
29+ direct_register_xpu_op = functools .partial (
30+ direct_register_custom_op , target_lib = xpu_kernels_lib
31+ )
32+ """Syntactic sugar for registering XPU custom ops."""
33+
2134rms_no_var = lambda x , weight , epsilon , variance_size = None : variance_size is None
2235
2336
@@ -31,6 +44,19 @@ def rms_norm(
3144 # Kernel requires weight tensor, pass ones
3245 weight = torch .ones (x .shape [- 1 ], device = x .device , dtype = x .dtype )
3346 assert variance_size is None
47+ return torch .ops .vllm_xpu .rms_norm (x , weight , epsilon )
48+
49+
50+ def _rms_norm_impl (x : Tensor , weight : Tensor , epsilon : float ) -> Tensor :
3451 output = torch .empty (x .shape , device = x .device , dtype = x .dtype )
3552 torch .ops ._C .rms_norm (output , x , weight , epsilon )
3653 return output
54+
55+
56+ def _rms_norm_fake (x : Tensor , weight : Tensor , epsilon : float ) -> Tensor :
57+ return torch .empty_like (x )
58+
59+
60+ direct_register_xpu_op (
61+ op_name = "rms_norm" , op_func = _rms_norm_impl , fake_impl = _rms_norm_fake
62+ )
0 commit comments