
Commit 9bdb06b

[XPU][6/N] add xpu scaled_mm kernel (vllm-project#34117)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
1 parent caad9f1 commit 9bdb06b

4 files changed: +67 −10 lines

.buildkite/scripts/hardware_ci/run-xpu-test.sh

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ docker run \
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
+python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --quantization fp8
 python3 examples/offline_inference/basic/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager
 python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
 python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
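The new CI line smoke-tests FP8 quantization on XPU through the offline-inference example script. A roughly equivalent standalone check through vLLM's Python API, assuming an XPU-enabled vLLM build (the prompt and sampling settings below are illustrative, not part of the commit):

from vllm import LLM, SamplingParams

# Mirrors the added test: facebook/opt-125m with FP8 quantization,
# eager mode, block size 64, on an XPU-enabled vLLM build.
llm = LLM(
    model="facebook/opt-125m",
    quantization="fp8",
    enforce_eager=True,
    block_size=64,
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)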

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 1 addition & 10 deletions
@@ -180,18 +180,9 @@ def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
             weight_block_size=weight_block_size,
         )

-    def get_xpu_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> "QuantizeMethodBase | None":
-        raise NotImplementedError(
-            "FP8 quantization is not supported during xpu kernel migration."
-        )
-
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> "QuantizeMethodBase | None":
-        if current_platform.is_xpu():
-            return self.get_xpu_quant_method(layer, prefix)
         if isinstance(layer, LinearBase):
             if is_layer_skipped(
                 prefix=prefix,
@@ -300,7 +291,7 @@ def __init__(self, quant_config: Fp8Config):
             or envs.VLLM_TEST_FORCE_FP8_MARLIN
         )
         # Disable marlin for rocm
-        if current_platform.is_rocm():
+        if current_platform.is_rocm() or current_platform.is_xpu():
             self.use_marlin = False
         if vllm_is_batch_invariant():
             self.use_marlin = False
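Net effect of these two hunks: on XPU, Fp8Config.get_quant_method no longer raises NotImplementedError and instead falls through to the common LinearBase handling, and the Marlin FP8 fallback is disabled the same way it already is on ROCm. A condensed sketch of the resulting Marlin decision (a simplification for illustration, not the full method body):

from vllm.platforms import current_platform

def use_marlin_fp8(force_marlin: bool) -> bool:
    # After this commit Marlin stays off on both ROCm and XPU, so XPU FP8
    # layers can resolve to the ScaledMM kernel path registered below.
    if current_platform.is_rocm() or current_platform.is_xpu():
        return False
    # Capability and batch-invariance checks are elided in this sketch.
    return force_marlin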

vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -39,6 +39,9 @@
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
     TritonInt8ScaledMMLinearKernel,
 )
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.xpu import (
+    XPUFP8ScaledMMLinearKernel,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
 from vllm.platforms import PlatformEnum, current_platform

@@ -72,6 +75,9 @@
         PerTensorTorchFP8ScaledMMLinearKernel,
         ChannelWiseTorchFP8ScaledMMLinearKernel,
     ],
+    PlatformEnum.XPU: [
+        XPUFP8ScaledMMLinearKernel,
+    ],
 }

 _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
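The registration above keys candidate scaled_mm kernels by PlatformEnum, with XPU now mapping to XPUFP8ScaledMMLinearKernel. A minimal sketch of how such a platform-keyed registry can be queried; the names `_KERNEL_REGISTRY` and `pick_kernel` are invented for illustration and are not this module's real identifiers:

from vllm.platforms import PlatformEnum

# Hypothetical registry mirroring the dict edited above.
_KERNEL_REGISTRY: dict[PlatformEnum, list[type]] = {
    # PlatformEnum.XPU: [XPUFP8ScaledMMLinearKernel], ...
}

def pick_kernel(platform: PlatformEnum, config) -> type:
    # Walk the candidates registered for this platform and return the first
    # one whose is_supported()/can_implement() checks both pass.
    for kernel_cls in _KERNEL_REGISTRY.get(platform, []):
        if kernel_cls.is_supported()[0] and kernel_cls.can_implement(config)[0]:
            return kernel_cls
    raise ValueError(f"no scaled_mm kernel available for {platform}")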
vllm/model_executor/layers/quantization/kernels/scaled_mm/xpu.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+
+import torch
+
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
+    FP8ScaledMMLinearKernel,
+    FP8ScaledMMLinearLayerConfig,
+)
+from vllm.platforms import current_platform
+
+
+class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
+    @classmethod
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_xpu():
+            return False, "XPUFP8ScaledMM only support on XPU"
+        return True, None
+
+    @classmethod
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            return False, "XPUFP8ScaledMM only support FP8 weight dtype"
+        return True, None
+
+    def __init__(
+        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
+    ) -> None:
+        assert self.can_implement(c)[0]
+        assert self.is_supported()[0]
+        self.config = c
+        self.layer_param_names = layer_param_names
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        weight = layer.weight
+        weight_scale = layer.weight_scale
+        return torch.ops._xpu_C.fp8_gemm_w8a16(x, weight, weight_scale, bias)
+
+    def apply_scaled_mm(
+        self,
+        *,
+        A: torch.Tensor,
+        B: torch.Tensor,
+        out_dtype: torch.dtype,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        bias: torch.Tensor | None,
+        output_shape: list,
+    ) -> torch.Tensor:
+        pass
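apply_weights hands the whole w8a16 GEMM to the custom op torch.ops._xpu_C.fp8_gemm_w8a16, which requires vLLM's XPU extension. For readers without that extension, here is a plain-PyTorch reference of what such a w8a16 matmul computes (dequantize the FP8 weight with its scale, then GEMM in the activation dtype); the (N, K) weight layout and per-output-channel scale shape are assumptions of this sketch, not guarantees of the kernel:

import torch

def fp8_gemm_w8a16_reference(
    x: torch.Tensor,             # activations in fp16/bf16, shape (M, K)
    weight: torch.Tensor,        # FP8 (e4m3fn/e5m2) weight, shape (N, K) assumed
    weight_scale: torch.Tensor,  # per-output-channel scale, shape (N,) assumed
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # Dequantize the FP8 weight into the activation dtype, then run a normal GEMM.
    w = weight.to(x.dtype) * weight_scale.to(x.dtype).reshape(-1, 1)
    out = x @ w.t()
    if bias is not None:
        out = out + bias
    return out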
