
Commit 502640c

ElizaWszola, ProExpertProg, and mgoin authored
[Perf] Fix and reapply move apply w8a8 block fp8 linear to class (vllm-project#25696)
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent 3d5f1c8 commit 502640c

File tree

13 files changed: +412 -200 lines changed

benchmarks/cutlass_benchmarks/w8a8_benchmarks.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@

 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    w8a8_block_fp8_matmul,
+    w8a8_triton_block_scaled_mm,
 )
 from vllm.utils import FlexibleArgumentParser, cdiv

@@ -158,7 +158,7 @@ def bench_fp8(
         "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
             a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
         ),
-        "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
+        "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
             a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
         ),
         "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
-    w8a8_block_fp8_matmul,
+    w8a8_triton_block_scaled_mm,
 )
 from vllm.triton_utils import triton
 from vllm.utils.deep_gemm import (

@@ -63,7 +63,7 @@ def deepgemm_gemm():

 # === vLLM Triton Implementation ===
 def vllm_triton_gemm():
-    return w8a8_block_fp8_matmul(A_vllm,
+    return w8a8_triton_block_scaled_mm(A_vllm,
                                  B_vllm,
                                  A_scale_vllm,
                                  B_scale_vllm,

tests/kernels/quantization/test_block_fp8.py

Lines changed: 3 additions & 2 deletions
@@ -11,7 +11,7 @@
     native_w8a8_block_matmul)
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+    cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import (fp8_gemm_nt,

@@ -91,7 +91,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):

     ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                        out_dtype)
-    out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
+    out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
+                                      out_dtype)

     rel_diff = (torch.mean(
         torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
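For orientation, a minimal standalone sketch of the renamed kernel's call pattern in the spirit of the test above. The call sites change only in name; the argument list is untouched. The shapes, the (128, 128) block size, the assumed (N, K) weight layout, and the naive fp8 casts below are illustrative assumptions, not the test's exact setup:

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)

# Assumed layout, following the test: A is (M, K) activations, B is (N, K)
# weights, with one scale per (1, block_k) group of A and one scale per
# (block_n, block_k) tile of B.
M, N, K = 16, 512, 512
block_n, block_k = 128, 128

A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B_fp8 = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
Bs = torch.rand(N // block_n, K // block_k, device="cuda", dtype=torch.float32)

# Group-quantize the activations along K in groups of block_k.
A_fp8, As = per_token_group_quant_fp8(A, block_k)

out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, [block_n, block_k],
                                  torch.bfloat16)
assert out.shape == (M, N)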

tests/kernels/quantization/test_fp8_quant_group.py

Lines changed: 21 additions & 9 deletions
@@ -20,9 +20,11 @@
     (8, 513, 64),  # Non-divisible (native only)
 ])
 @pytest.mark.parametrize("seed", [42])
+@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
 def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
-                                      group_size: int, seed: int) -> None:
+                                      group_size: int, seed: int,
+                                      use_ue8m0: bool) -> None:
     """Test QuantFP8 group quantization with various configurations.

     Tests both CUDA and native implementations, column-major scales,

@@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False)
+                        column_major_scales=False,
+                        use_ue8m0=use_ue8m0)

     # 1. Test native implementation (always available)
     x_quant_native, scales_native = quant_op.forward_native(x.clone())

@@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     # 2. Test column-major scales configuration
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True)
+                            column_major_scales=True,
+                            use_ue8m0=use_ue8m0)
     _, scales_col = quant_op_col.forward_native(x.clone())
-    assert scales_col.shape == (expected_num_groups, batch_size)
+    assert scales_col.shape == (batch_size, expected_num_groups)
+    assert scales_col.stride(0) == 1
+    assert scales_col.stride(1) == batch_size
+
+    # Test column-major scales consistency
+    assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)

     # 3. Test CUDA implementation (only for divisible dimensions)
     if is_divisible:

@@ -68,21 +77,23 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,


 @pytest.mark.parametrize("seed", [42])
+@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
-def test_quantfp8_group_multidimensional(seed: int) -> None:
+def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
     current_platform.seed_everything(seed)

     group_size = 64

     # Test with 3D input
-    batch1, batch2, hidden_dim = 4, 8, 512
+    batch1, batch2, hidden_dim = 4, 8, 1024
     x_3d = torch.randn(
         (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False)
+                        column_major_scales=False,
+                        use_ue8m0=use_ue8m0)

     x_quant, scales = quant_op.forward_native(x_3d.clone())
     assert x_quant.shape == x_3d.shape

@@ -91,9 +102,10 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
     # Test column_major_scales with multi-dim
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True)
+                            column_major_scales=True,
+                            use_ue8m0=use_ue8m0)
     _, scales_col = quant_op_col.forward_native(x_3d.clone())
-    assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)
+    assert scales_col.shape == (batch1, batch2, hidden_dim // group_size)

     # Test with 4D input
     batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
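The updated assertions capture the behavioral change being tested: with column_major_scales=True the group scales now come back with logical shape (tokens, num_groups) but column-major strides, rather than a physically transposed (num_groups, tokens) tensor, and both tests are parametrized over use_ue8m0. A small sketch of that contract; the shapes are illustrative and the GroupShape import path is an assumption:

import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

batch_size, hidden_dim, group_size = 16, 512, 64
x = torch.randn(batch_size, hidden_dim, device="cuda", dtype=torch.bfloat16)

quant_op = QuantFP8(static=False,
                    group_shape=GroupShape(1, group_size),
                    column_major_scales=True,
                    use_ue8m0=False)
x_q, scales = quant_op.forward_native(x)

# Logical shape is (tokens, groups); the memory layout is column-major.
assert scales.shape == (batch_size, hidden_dim // group_size)
assert scales.stride(0) == 1 and scales.stride(1) == batch_size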

tests/model_executor/test_enabled_custom_ops.py

Lines changed: 0 additions & 30 deletions
@@ -17,8 +17,6 @@
 from vllm.model_executor.layers.layernorm import (RMSNorm,
                                                   dispatch_rocm_rmsnorm_func,
                                                   fused_add_rms_norm, rms_norm)
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform

 RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]

@@ -111,34 +109,6 @@ def test_enabled_ops_invalid(env: str):
         RMSNorm(1024).enabled()


-@pytest.mark.skipif(
-    not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
-    reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
-@pytest.mark.parametrize("use_cutlass", [True, False])
-@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
-@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
-def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
-                                  use_rocm_aiter_gemm_w8a8_blockscale: str,
-                                  monkeypatch):
-
-    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
-    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
-                       use_rocm_aiter_gemm_w8a8_blockscale)
-
-    use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
-        int(use_rocm_aiter_gemm_w8a8_blockscale)))
-    block_scale_func = dispatch_w8a8_blockscale_func(
-        use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
-    if use_cutlass:
-        assert block_scale_func == cutlass_scaled_mm
-    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
-            use_rocm_aiter_gemm_w8a8_blockscale):
-        assert block_scale_func == (
-            torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
-    else:
-        assert block_scale_func == w8a8_block_fp8_matmul
-
-
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
     monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)

tests/quantization/test_compressed_tensors.py

Lines changed: 35 additions & 0 deletions
@@ -18,6 +18,9 @@
     CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    W8A8BlockFp8LinearOp)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     cutlass_fp4_supported)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (

@@ -742,3 +745,35 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
     perplexity = llm.generate_prompt_perplexity([prompt])[0]
     print(perplexity)
     assert perplexity <= exp_perplexity
+
+
+def test_compressed_tensors_fp8_block_enabled(vllm_runner):
+    model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
+    with vllm_runner(model_path) as llm:
+
+        fp8_dtype = current_platform.fp8_dtype()
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            assert isinstance(qkv_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
+            assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear,
+                              W8A8BlockFp8LinearOp)
+
+            assert qkv_proj.weight.dtype is fp8_dtype
+            assert qkv_proj.weight_scale.dtype is torch.float32
+            assert len(qkv_proj.weight.shape) == 2
+            assert len(qkv_proj.weight_scale.shape) == 2
+
+            input_quant_op = \
+                qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
+            assert isinstance(input_quant_op, QuantFP8)
+            assert input_quant_op._forward_method == input_quant_op.forward_cuda
+
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output
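For context, the checkpoint used by the new test can be exercised end to end with the regular offline API; this is only a usage sketch, not part of the test:

from vllm import LLM, SamplingParams

# RedHatAI/Qwen3-0.6B-FP8-BLOCK ships block-quantized FP8 weights, so its
# linear layers take the W8A8BlockFp8LinearOp path introduced in this commit.
llm = LLM(model="RedHatAI/Qwen3-0.6B-FP8-BLOCK")
outputs = llm.generate(["Hello my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)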

vllm/config/vllm.py

Lines changed: 17 additions & 0 deletions
@@ -516,6 +516,23 @@ def __post_init__(self):
                 " by VLLM_DEBUG_DUMP_PATH to %s", env_path)
             self.compilation_config.debug_dump_path = env_path

+        def has_blocked_weights():
+            if self.quant_config is not None:
+                if hasattr(self.quant_config, "weight_block_size"):
+                    return self.quant_config.weight_block_size is not None
+                elif hasattr(self.quant_config, "has_blocked_weights"):
+                    return self.quant_config.has_blocked_weights()
+            return False
+
+        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
+        # On H100 the CUDA kernel is faster than
+        # native implementation
+        # https://github.com/vllm-project/vllm/issues/25094
+        if has_blocked_weights():
+            custom_ops = self.compilation_config.custom_ops
+            if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
+                custom_ops.append("+quant_fp8")
+
     def update_sizes_for_sequence_parallelism(self,
                                               possible_sizes: list) -> list:
         # remove the sizes that not multiple of tp_size when
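In effect, any quant config that reports blocked weights now opts the model into the quant_fp8 CUDA custom op unless the user has disabled it explicitly. A standalone restatement of that decision, for illustration only (the helper below is hypothetical, not part of the diff):

def updated_custom_ops(custom_ops: list[str],
                       has_blocked_weights: bool) -> list[str]:
    # Mirrors the __post_init__ logic above: only opt in when the user has not
    # globally disabled custom ops ("none") or explicitly disabled quant_fp8.
    if has_blocked_weights:
        if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
            custom_ops.append("+quant_fp8")
    return custom_ops


assert updated_custom_ops([], True) == ["+quant_fp8"]
assert updated_custom_ops(["-quant_fp8"], True) == ["-quant_fp8"]  # user opt-out wins
assert updated_custom_ops(["none"], True) == ["none"]  # all custom ops disabled
assert updated_custom_ops([], False) == []  # no blocked weights: unchanged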

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 8 additions & 0 deletions
@@ -644,6 +644,14 @@ def get_cache_scale(self, name: str) -> Optional[str]:
         # If no matches, return None
         return None

+    def has_blocked_weights(self) -> bool:
+        for scheme in self.target_scheme_map.values():
+            weight_quant = scheme.get("weights")
+            if (weight_quant is not None
+                    and weight_quant.strategy == QuantizationStrategy.BLOCK):
+                return True
+        return False
+
     @staticmethod
     def supports_cutlass_24(
             weight_quant: Optional[QuantizationArgs],

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

Lines changed: 26 additions & 11 deletions
@@ -11,7 +11,7 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_fp8_block_linear, check_aiter_fp8_linear_support,
+    W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
     create_fp8_input_scale, create_fp8_scale_parameter,
     create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
     process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,

@@ -41,16 +41,30 @@ def __init__(self, weight_quant: QuantizationArgs,
         self.strategy = weight_quant.strategy
         self.out_dtype = torch.get_default_dtype()
         self.is_static_input_scheme = is_static_input_scheme
-        self.act_q_group_shape = GroupShape.PER_TENSOR \
-            if is_static_input_scheme else GroupShape.PER_TOKEN
-        self.fp8_linear = Fp8LinearOp(
-            act_quant_static=self.is_static_input_scheme,
-            act_quant_group_shape=self.act_q_group_shape)

         self.weight_block_size = self.weight_quant.block_structure
+        if self.weight_block_size is not None:
+            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
+        else:
+            self.act_q_group_shape = GroupShape.PER_TENSOR \
+                if is_static_input_scheme else GroupShape.PER_TOKEN
+
         self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

+        if self.weight_block_size is not None:
+            assert not self.is_static_input_scheme
+            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
+                weight_group_shape=GroupShape(*self.weight_block_size),
+                act_quant_group_shape=self.act_q_group_shape,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
+            )
+        else:
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=self.is_static_input_scheme,
+                act_quant_group_shape=self.act_q_group_shape)
+
     @classmethod
     def get_min_capability(cls) -> int:
         # lovelace and up

@@ -142,13 +156,14 @@ def apply_weights(self,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:

-        if layer.weight_block_size is not None:
-            return apply_fp8_block_linear(
-                layer,
+        if self.weight_block_size is not None:
+            return self.w8a8_block_fp8_linear.apply(
                 input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                input_scale=layer.input_scale,
                 bias=bias,
-                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
-                use_aiter_and_is_supported=self.use_aiter_and_is_supported)
+            )

         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
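A condensed sketch of the new wiring: the scheme builds a single W8A8BlockFp8LinearOp up front and simply forwards the layer tensors at apply time. Constructor and apply arguments follow the diff; the block size, the GroupShape import path, and the hard-coded support flag are illustrative assumptions:

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

weight_block_size = [128, 128]  # taken from the checkpoint's block structure

# Activations are quantized in (1, block) groups, mirroring the diff's
# GroupShape(1, weight_block_size[0]).
block_linear = W8A8BlockFp8LinearOp(
    weight_group_shape=GroupShape(*weight_block_size),
    act_quant_group_shape=GroupShape(1, weight_block_size[0]),
    cutlass_block_fp8_supported=False,  # real code queries cutlass_block_fp8_supported()
    use_aiter_and_is_supported=check_aiter_fp8_linear_support(),
)

# Inside apply_weights(), with `layer` holding the created parameters:
# out = block_linear.apply(input=x, weight=layer.weight,
#                          weight_scale=layer.weight_scale,
#                          input_scale=layer.input_scale, bias=bias)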

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 28 additions & 12 deletions
@@ -33,7 +33,7 @@
     register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
     select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_fp8_block_linear, check_aiter_fp8_linear_support,
+    W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
     create_fp8_input_scale, create_fp8_scale_parameter,
     create_fp8_weight_parameter, expert_weight_is_col_major,
     maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,

@@ -242,15 +242,28 @@ def __init__(self, quant_config: Fp8Config):
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None
         self.act_q_static = self.quant_config.activation_scheme == "static"
-        # Use per-token quantization for better perf if dynamic and cutlass
-        if not self.act_q_static and cutlass_fp8_supported():
-            self.act_q_group_shape = GroupShape.PER_TOKEN
+        if self.weight_block_size:
+            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
         else:
-            self.act_q_group_shape = GroupShape.PER_TENSOR
+            # Use per-token quantization for better perf if dynamic and cutlass
+            if not self.act_q_static and cutlass_fp8_supported():
+                self.act_q_group_shape = GroupShape.PER_TOKEN
+            else:
+                self.act_q_group_shape = GroupShape.PER_TENSOR

-        self.fp8_linear = Fp8LinearOp(
-            act_quant_static=self.act_q_static,
-            act_quant_group_shape=self.act_q_group_shape)
+        if self.block_quant:
+            assert not self.act_q_static
+            assert self.weight_block_size is not None
+            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
+                weight_group_shape=GroupShape(*self.weight_block_size),
+                act_quant_group_shape=self.act_q_group_shape,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
+            )
+        else:
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=self.act_q_static,
+                act_quant_group_shape=self.act_q_group_shape)

     def create_weights(
         self,

@@ -399,12 +412,15 @@ def apply(self,
                          bias=bias)

         if self.block_quant:
-            return apply_fp8_block_linear(
-                layer,
+            assert self.weight_block_size is not None
+
+            return self.w8a8_block_fp8_linear.apply(
                 input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                input_scale=layer.input_scale,
                 bias=bias,
-                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
-                use_aiter_and_is_supported=self.use_aiter_and_is_supported)
+            )

         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
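Both the compressed-tensors scheme and Fp8LinearMethod now pick the activation-quantization group shape the same way; restated on its own it reads as follows (a sketch, with the GroupShape import path assumed):

from typing import Optional

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape


def choose_act_q_group_shape(weight_block_size: Optional[list[int]],
                             act_q_static: bool,
                             cutlass_fp8_supported: bool) -> GroupShape:
    if weight_block_size:
        # Block-quantized weights: quantize activations in (1, block) groups
        # so the activation scales line up with the weight blocks.
        return GroupShape(1, weight_block_size[0])
    if not act_q_static and cutlass_fp8_supported:
        # Dynamic quantization + CUTLASS: per-token scales are faster.
        return GroupShape.PER_TOKEN
    return GroupShape.PER_TENSOR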

0 commit comments
