From 494f65312005e1e6c8311da6671f32b12c7f7c5d Mon Sep 17 00:00:00 2001 From: zjli2013 Date: Fri, 25 Jul 2025 20:24:20 +0800 Subject: [PATCH 1/7] init code for fused_rms_quant --- vllm/_aiter_ops.py | 71 ++++++++++++++++++++++- vllm/envs.py | 1 + vllm/model_executor/models/deepseek_v2.py | 21 ++++++- 3 files changed, 89 insertions(+), 4 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c084f410c573..a197adb685c9 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -45,6 +45,56 @@ def rocm_aiter_tuned_gemm_fake( return torch.empty((m, n), dtype=out_dtype, device=input.device) +def rocm_aiter_rmsnorm2d_fwd_with_add_quant_impl( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float, x_scale=None, y_scale_dtype=None, + q_dtype="fp8", model_sensitive=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + import aiter as rocm_aiter + from aiter import dtypes + quant_dtype_map = { + "i8": dtypes.i8 , + "fp8": dtypes.fp8 + } + q_dtype = quant_dtype_map[q_dtype] + + if x_scale is None: + output = torch.empty(x.shape, dtype=q_dtype, devive="cuda") + y_scale = torch.empty(x.shape[0], 1, dtype=y_scale_dtype, device="cuda") #TODO: only per-token quant now + if residual is None: + residual_out = None + rocm_aiter.rmsnorm2d_fwd_with_dynamicquant( + output, x, y_scale, weight, variance_epsilon, model_sensitive + ) + elif residual is not None: + residual_out = torch.empty_like(x) + rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant( + output, x, residual, residual_out, y_scale, weight, variance_epsilon, model_sensitive + ) + else: + output = torch.empty(x.shape, dtype=q_dtype, devive="cuda") + y_scale = torch.empty(x.shape[0], 1, dtype=y_scale_dtype, device="cuda") #TODO: only per-token quant now + if residual is None: + residual_out = None + aiter.rmsnorm2d_fwd_with_smoothquant( + output, x, x_scale, y_scale, weight, variance_epsilon, model_sensitive + ) + elif residual is not None: + residual_out = torch.empty_like(x) + out_before_quant = torch.empty_like(x) + aiter.rmsnorm2d_fwd_with_add_smoothquant( + output, input, residual, residual_out, x_scale, y_scale, weight, variance_epsilon, + out_before_quant=out_before_quant, + ) + return output, residual_out, y_scale, out_before_quant + +def rocm_aiter_rmsnorm2d_fwd_with_add_quant_fake( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float, x_scale=None, y_scale_dtype=None, + q_dtype="fp8", model_sensitive=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return torch.empty(x.shape, dtype=q_dtype, device="cuda"), torch.empty_like(x), torch.empty(x.shape[0], 1, dtype=y_scale_dtype, device="cuda"), torch.empty_list(x) + + if current_platform.is_rocm(): direct_register_custom_op( op_name="rocm_aiter_tuned_gemm", @@ -54,6 +104,13 @@ def rocm_aiter_tuned_gemm_fake( dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm2d_fwd_with_add_quant", + op_func=rocm_aiter_rmsnorm2d_fwd_with_add_quant_impl, + mutates_args=[], + fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) class aiter_ops: @@ -73,4 +130,16 @@ def rocm_aiter_tuned_gemm( out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, - ) \ No newline at end of file + ) + + @staticmethod + def rocm_aiter_rmsnorm2d_fwd_with_add_quant( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float, x_scale=None, y_scale_dtype=None, + q_dtype="fp8", 
model_sensitive=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add_quant( + x, residual, weight, variance_epsilon, x_scale, + y_scale_dtype, q_dtype, model_sensitive + ) + \ No newline at end of file diff --git a/vllm/envs.py b/vllm/envs.py index 8f0fd84f32f8..dbade806f821 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -79,6 +79,7 @@ VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_ASMMOE: bool = False VLLM_ROCM_USE_AITER_RMSNORM: bool = True + VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_ROPE: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 23b450aeddac..6b791442b02e 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -57,6 +57,14 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +import vllm.envs as envs +from vllm.platforms import current_platform +from vllm._aiter_ops import aiter_ops + +def is_rocm_aiter_rmsnorm_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_RMSNORM \ + and envs.VLLM_ROCM_USE_AITER class DeepseekV2MLP(nn.Module): @@ -294,9 +302,16 @@ def forward( ) -> torch.Tensor: if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] - q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, - self.qk_head_dim) + if is_rocm_aiter_rmsnorm_enabled(): + q_dtype = "fp8" # TODO, get quant type correctly + q, residual_out, y_scale, q_before_quant = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(q, None, + self.q_a_layernorm.weight, self.q_a_layernorm.variance_epsilon, + None, torch.float32, q_dtype) + + else: + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) else: q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, self.qk_head_dim) From 97634271cc876e301c308c7a5f892282359fab05 Mon Sep 17 00:00:00 2001 From: zjli2013 Date: Mon, 28 Jul 2025 10:04:36 +0800 Subject: [PATCH 2/7] add fuse_rmsnorm_quant into deepseek-v2 q_b_proj for verification --- vllm/_aiter_ops.py | 35 +++++++++---------- vllm/model_executor/layers/linear.py | 5 +-- .../model_executor/layers/quantization/fp8.py | 18 ++++++++-- .../layers/quantization/utils/w8a8_utils.py | 34 ++++++++++-------- vllm/model_executor/models/deepseek_v2.py | 12 +++---- 5 files changed, 60 insertions(+), 44 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index a197adb685c9..b93f0c374b03 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -46,7 +46,7 @@ def rocm_aiter_tuned_gemm_fake( def rocm_aiter_rmsnorm2d_fwd_with_add_quant_impl( - x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, variance_epsilon: float, x_scale=None, y_scale_dtype=None, q_dtype="fp8", model_sensitive=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -57,42 +57,41 @@ def rocm_aiter_rmsnorm2d_fwd_with_add_quant_impl( "fp8": dtypes.fp8 } q_dtype = quant_dtype_map[q_dtype] + assert y_scale_dtype is not None # TODO if x_scale is None: - output = torch.empty(x.shape, dtype=q_dtype, devive="cuda") - y_scale = torch.empty(x.shape[0], 1, dtype=y_scale_dtype, device="cuda") #TODO: only per-token quant now + output = torch.empty(input.shape, dtype=q_dtype, devive="cuda") + y_scale 
= torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") #TODO: only per-token quant now if residual is None: residual_out = None rocm_aiter.rmsnorm2d_fwd_with_dynamicquant( - output, x, y_scale, weight, variance_epsilon, model_sensitive + output, input, y_scale, weight, variance_epsilon, model_sensitive ) elif residual is not None: - residual_out = torch.empty_like(x) + residual_out = torch.empty_like(input) rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant( - output, x, residual, residual_out, y_scale, weight, variance_epsilon, model_sensitive + output, input, residual, residual_out, y_scale, weight, variance_epsilon, model_sensitive ) else: - output = torch.empty(x.shape, dtype=q_dtype, devive="cuda") - y_scale = torch.empty(x.shape[0], 1, dtype=y_scale_dtype, device="cuda") #TODO: only per-token quant now + output = torch.empty(input.shape, dtype=q_dtype, devive="cuda") + y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") #TODO: only per-token quant now if residual is None: residual_out = None aiter.rmsnorm2d_fwd_with_smoothquant( - output, x, x_scale, y_scale, weight, variance_epsilon, model_sensitive + output, input, x_scale, y_scale, weight, variance_epsilon, model_sensitive ) elif residual is not None: - residual_out = torch.empty_like(x) - out_before_quant = torch.empty_like(x) + residual_out = torch.empty_like(input) aiter.rmsnorm2d_fwd_with_add_smoothquant( - output, input, residual, residual_out, x_scale, y_scale, weight, variance_epsilon, - out_before_quant=out_before_quant, + output, input, residual, residual_out, x_scale, y_scale, weight, variance_epsilon ) - return output, residual_out, y_scale, out_before_quant + return output, residual_out, y_scale def rocm_aiter_rmsnorm2d_fwd_with_add_quant_fake( - x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, variance_epsilon: float, x_scale=None, y_scale_dtype=None, q_dtype="fp8", model_sensitive=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - return torch.empty(x.shape, dtype=q_dtype, device="cuda"), torch.empty_like(x), torch.empty(x.shape[0], 1, dtype=y_scale_dtype, device="cuda"), torch.empty_list(x) + return torch.empty(input.shape, dtype=q_dtype, device="cuda"), torch.empty_like(input), torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") if current_platform.is_rocm(): @@ -134,12 +133,12 @@ def rocm_aiter_tuned_gemm( @staticmethod def rocm_aiter_rmsnorm2d_fwd_with_add_quant( - x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, variance_epsilon: float, x_scale=None, y_scale_dtype=None, q_dtype="fp8", model_sensitive=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add_quant( - x, residual, weight, variance_epsilon, x_scale, + input, residual, weight, variance_epsilon, x_scale, y_scale_dtype, q_dtype, model_sensitive ) \ No newline at end of file diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 802e0163a41d..3466cd43ab5c 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -532,13 +532,14 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): param.load_column_parallel_weight(loaded_weight=loaded_weight) def forward( - self, input_ + self, input_, input_scale: Optional[torch.Tensor]=None ) -> Union[torch.Tensor, 
tuple[torch.Tensor, Optional[Parameter]]]: bias = self.bias if not self.skip_bias_add else None # Matrix multiply. assert self.quant_method is not None - output_parallel = self.quant_method.apply(self, input_, bias) + output_parallel = self.quant_method.apply(self, input_, bias, input_scale) + # for fused_rmsnorm_quant usage if self.gather_output: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f286fe824029..70875bd230c0 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -396,7 +396,19 @@ def process_weights_after_loading(self, layer: Module) -> None: def apply(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + bias: Optional[torch.Tensor] = None, + input_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Apply the FP8 linear transformation. + + Args: + layer: The linear layer module + x: Input tensor + bias: Optional bias tensor + input_scale: Optional input scale tensor from fused_rmsnorm_quant + """ + # Use provided input_scale if available, otherwise use layer's input_scale + effective_input_scale = input_scale if input_scale is not None else layer.input_scale if self.use_marlin: return apply_fp8_marlin_linear( @@ -415,7 +427,7 @@ def apply(self, weight=layer.weight, block_size=self.quant_config.weight_block_size, weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, + input_scale=effective_input_scale, bias=bias, cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported, @@ -425,7 +437,7 @@ def apply(self, weight=layer.weight, weight_scale=layer.weight_scale, out_dtype=self.out_dtype, - input_scale=layer.input_scale, + input_scale=effective_input_scale, bias=bias) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 774943bd50c5..b52e88509a95 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -387,23 +387,27 @@ def apply( # so fallback to naive if per channel or per token else: if input.dtype != current_platform.fp8_dtype(): - - if not self.use_aiter_and_is_supported: - # Maybe apply padding to output, see comment in __init__ - qinput, x_scale = ops.scaled_fp8_quant( - input_2d, - input_scale, - num_token_padding=self.output_padding, - use_per_token_if_dynamic=use_per_token_if_dynamic) + # # Handle fused RMSNorm + quant case + if envs.VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT and input_scale is not None: + print(f"[DEBUG FUSED_RMS_QUANT], input_scale in Fp8LinearOp") + qinput, x_scale = input_2d, input_scale else: - if use_per_token_if_dynamic: - qinput, x_scale = ( - torch.ops.vllm.rocm_aiter_per_token_quant_fp8( - input_2d, scale=input_scale)) + if not self.use_aiter_and_is_supported: + # Maybe apply padding to output, see comment in __init__ + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + num_token_padding=self.output_padding, + use_per_token_if_dynamic=use_per_token_if_dynamic) else: - qinput, x_scale = ( - torch.ops.vllm.rocm_aiter_per_tensor_quant_fp8( - input_2d, scale=input_scale)) + if use_per_token_if_dynamic: + qinput, x_scale = ( + torch.ops.vllm.rocm_aiter_per_token_quant_fp8( + input_2d, scale=input_scale)) + 
else: + qinput, x_scale = ( + torch.ops.vllm.rocm_aiter_per_tensor_quant_fp8( + input_2d, scale=input_scale)) else: qinput, x_scale = input_2d, input_scale diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 6b791442b02e..5ff7011be9f7 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -64,7 +64,8 @@ def is_rocm_aiter_rmsnorm_enabled() -> bool: return current_platform.is_rocm() \ and envs.VLLM_ROCM_USE_AITER_RMSNORM \ - and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT class DeepseekV2MLP(nn.Module): @@ -303,11 +304,10 @@ def forward( if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] if is_rocm_aiter_rmsnorm_enabled(): - q_dtype = "fp8" # TODO, get quant type correctly - q, residual_out, y_scale, q_before_quant = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(q, None, - self.q_a_layernorm.weight, self.q_a_layernorm.variance_epsilon, - None, torch.float32, q_dtype) - + q, residual_out, y_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(q, residual=None, + weight=self.q_a_layernorm.weight, variance_epsilon=self.q_a_layernorm.variance_epsilon, + x_scale=None, y_scale_type=torch.float32, q_dtype="fp8") + q = self.q_b_proj(q, y_scale).view(-1, self.num_local_heads, self.qk_head_dim) else: q = self.q_a_layernorm(q) q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, From 25a225762d71e98cbcc0ceb803225f43406050de Mon Sep 17 00:00:00 2001 From: zjli2013 Date: Mon, 28 Jul 2025 15:18:03 +0800 Subject: [PATCH 3/7] type error fixed --- vllm/_aiter_ops.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index b93f0c374b03..578e32bc73f3 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -6,7 +6,6 @@ from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op - def rocm_aiter_tuned_gemm_impl( input: torch.Tensor, weight: torch.Tensor, @@ -47,20 +46,14 @@ def rocm_aiter_tuned_gemm_fake( def rocm_aiter_rmsnorm2d_fwd_with_add_quant_impl( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float, x_scale=None, y_scale_dtype=None, - q_dtype="fp8", model_sensitive=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - + variance_epsilon: float, x_scale: Optional[torch.Tensor]=None, y_scale_dtype: Optional[torch.dtype]=None, + q_dtype: torch.dtype=torch.float8_e4m3fnuz, model_sensitive: float=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # TODO: make q_dtype general import aiter as rocm_aiter - from aiter import dtypes - quant_dtype_map = { - "i8": dtypes.i8 , - "fp8": dtypes.fp8 - } - q_dtype = quant_dtype_map[q_dtype] assert y_scale_dtype is not None # TODO if x_scale is None: - output = torch.empty(input.shape, dtype=q_dtype, devive="cuda") + output = torch.empty(input.shape, dtype=q_dtype, device="cuda") y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") #TODO: only per-token quant now if residual is None: residual_out = None @@ -73,24 +66,24 @@ def rocm_aiter_rmsnorm2d_fwd_with_add_quant_impl( output, input, residual, residual_out, y_scale, weight, variance_epsilon, model_sensitive ) else: - output = torch.empty(input.shape, dtype=q_dtype, devive="cuda") + output = torch.empty(input.shape, dtype=q_dtype, device="cuda") y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype, 
device="cuda") #TODO: only per-token quant now if residual is None: residual_out = None - aiter.rmsnorm2d_fwd_with_smoothquant( + rocm_aiter.rmsnorm2d_fwd_with_smoothquant( output, input, x_scale, y_scale, weight, variance_epsilon, model_sensitive ) elif residual is not None: residual_out = torch.empty_like(input) - aiter.rmsnorm2d_fwd_with_add_smoothquant( + rocm_aiter.rmsnorm2d_fwd_with_add_smoothquant( output, input, residual, residual_out, x_scale, y_scale, weight, variance_epsilon ) return output, residual_out, y_scale def rocm_aiter_rmsnorm2d_fwd_with_add_quant_fake( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float, x_scale=None, y_scale_dtype=None, - q_dtype="fp8", model_sensitive=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + variance_epsilon: float, x_scale: Optional[torch.Tensor] = None, y_scale_dtype: Optional[torch.dtype] = None, + q_dtype: torch.dtype=torch.float8_e4m3fnuz, model_sensitive: float=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: return torch.empty(input.shape, dtype=q_dtype, device="cuda"), torch.empty_like(input), torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") @@ -134,8 +127,8 @@ def rocm_aiter_tuned_gemm( @staticmethod def rocm_aiter_rmsnorm2d_fwd_with_add_quant( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float, x_scale=None, y_scale_dtype=None, - q_dtype="fp8", model_sensitive=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + variance_epsilon: float, x_scale: Optional[torch.Tensor] = None, y_scale_dtype: Optional[torch.dtype] = None, + q_dtype: torch.dtype=torch.float8_e4m3fnuz, model_sensitive: float=0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add_quant( input, residual, weight, variance_epsilon, x_scale, From a2b71f4b4836448f7c52ac791c58077a0e32b70a Mon Sep 17 00:00:00 2001 From: zjli2013 Date: Mon, 28 Jul 2025 17:10:48 +0800 Subject: [PATCH 4/7] add envs for fused_rms_quant and tmp fix for linear bug --- vllm/envs.py | 5 +++++ vllm/model_executor/layers/linear.py | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index dbade806f821..a99d3d6c62d4 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -590,6 +590,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")), + + # for fused rmsnorm and fp8 quant kernel from aiter + "VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT", "True").lower() in + ("true", "1")), # Divisor for dynamic query scale factor calculation for FP8 KV Cache "Q_SCALE_CONSTANT": diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3466cd43ab5c..be9003510728 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -538,7 +538,11 @@ def forward( # Matrix multiply. assert self.quant_method is not None - output_parallel = self.quant_method.apply(self, input_, bias, input_scale) + # TODO: bug for gate_up_proj(x) + if input_scale is not None: + output_parallel = self.quant_method.apply(self, input_, bias, input_scale) + else: + output_parallel = self.quant_method.apply(self, input_, bias) # for fused_rmsnorm_quant usage if self.gather_output: # All-gather across the partitions. 
From 9effd57f42d49e118f5463124a333c39a845babb Mon Sep 17 00:00:00 2001 From: zjli2013 Date: Mon, 28 Jul 2025 21:32:16 +0800 Subject: [PATCH 5/7] add inputscale in tensorCompressed related classes --- vllm/model_executor/layers/linear.py | 6 +----- .../quantization/compressed_tensors/compressed_tensors.py | 5 +++-- .../schemes/compressed_tensors_scheme.py | 3 ++- .../schemes/compressed_tensors_w8a8_fp8.py | 7 +++++-- vllm/model_executor/models/deepseek_v2.py | 8 ++++---- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index be9003510728..3466cd43ab5c 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -538,11 +538,7 @@ def forward( # Matrix multiply. assert self.quant_method is not None - # TODO: bug for gate_up_proj(x) - if input_scale is not None: - output_parallel = self.quant_method.apply(self, input_, bias, input_scale) - else: - output_parallel = self.quant_method.apply(self, input_, bias) + output_parallel = self.quant_method.apply(self, input_, bias, input_scale) # for fused_rmsnorm_quant usage if self.gather_output: # All-gather across the partitions. diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index cb9a48d7746b..b0c91ea665b6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -566,7 +566,8 @@ def create_weights(self, layer: torch.nn.Module, def apply(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None): + bias: Optional[torch.Tensor] = None, + input_scale: Optional[torch.Tensor] = None): """ Use the output of create_weights and the CompressedTensorsScheme associated with the layer to apply the forward pass with the @@ -577,7 +578,7 @@ def apply(self, scheme = layer.scheme if scheme is None: raise ValueError("A scheme must be defined for each layer") - return scheme.apply_weights(layer, x, bias=bias) + return scheme.apply_weights(layer, x, bias=bias, input_scale=input_scale) class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py index daa25d23a306..7a09bb4cc9e0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -32,7 +32,7 @@ def create_weights(self, *args, **kwargs): @abstractmethod def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]): + bias: Optional[torch.Tensor], input_scale: Optional[torch.Tensor]): """ Run the forward pass for the particular scheme. This is where scheme-specific dequant/quant steps/kernels should be applied. @@ -41,6 +41,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, other parameters relevant to the particular scheme. 
:param x: input to the layer :param bias: bias parameter + :parm input_scale: input scale used for fused_rmsnorm_quant """ raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 59fa4d17333b..e44556dbaa23 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -156,11 +156,14 @@ def create_weights(self, layer: torch.nn.Module, def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + bias: Optional[torch.Tensor] = None, + input_scale: Optional[torch.Tensor]=None) -> torch.Tensor: + + effective_input_scale = input_scale if input_scale is not None else layer.input_scale return self.fp8_linear.apply(input=x, weight=layer.weight, weight_scale=layer.weight_scale, out_dtype=self.out_dtype, - input_scale=layer.input_scale, + input_scale=effective_input_scale, bias=bias) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 5ff7011be9f7..7599d0fe1b77 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -61,7 +61,7 @@ from vllm.platforms import current_platform from vllm._aiter_ops import aiter_ops -def is_rocm_aiter_rmsnorm_enabled() -> bool: +def is_rocm_aiter_fused_rms_quant_enabled() -> bool: return current_platform.is_rocm() \ and envs.VLLM_ROCM_USE_AITER_RMSNORM \ and envs.VLLM_ROCM_USE_AITER \ @@ -303,11 +303,11 @@ def forward( ) -> torch.Tensor: if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] - if is_rocm_aiter_rmsnorm_enabled(): + if is_rocm_aiter_fused_rms_quant_enabled(): q, residual_out, y_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(q, residual=None, weight=self.q_a_layernorm.weight, variance_epsilon=self.q_a_layernorm.variance_epsilon, - x_scale=None, y_scale_type=torch.float32, q_dtype="fp8") - q = self.q_b_proj(q, y_scale).view(-1, self.num_local_heads, self.qk_head_dim) + x_scale=None, y_scale_dtype=torch.float32) + q = self.q_b_proj(q, y_scale)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: q = self.q_a_layernorm(q) q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, From a6d66744897f7faa05c4520e8b39bdf1c2a856bf Mon Sep 17 00:00:00 2001 From: zjli2013 Date: Tue, 29 Jul 2025 13:46:06 +0800 Subject: [PATCH 6/7] add fused_rms_quant in DecoderLayer --- vllm/model_executor/layers/linear.py | 7 ++- .../layers/quantization/utils/w8a8_utils.py | 36 ++++++----- vllm/model_executor/models/deepseek_v2.py | 61 ++++++++++++++----- 3 files changed, 67 insertions(+), 37 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3466cd43ab5c..4649818e6923 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -381,11 +381,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.data.copy_(loaded_weight) def forward( - self, x: torch.Tensor + self, x: torch.Tensor, input_scale: Optional[torch.Tensor] = None ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None - output = self.quant_method.apply(self, x, bias) + if 
input_scale is not None: + output = self.quant_method.apply(self, x, bias, input_scale) + else: + output = self.quant_method.apply(self, x, bias) output_bias = self.bias if self.skip_bias_add else None if not self.return_bias: return output diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index b52e88509a95..ac10a346e7c2 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -387,30 +387,28 @@ def apply( # so fallback to naive if per channel or per token else: if input.dtype != current_platform.fp8_dtype(): - # # Handle fused RMSNorm + quant case - if envs.VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT and input_scale is not None: - print(f"[DEBUG FUSED_RMS_QUANT], input_scale in Fp8LinearOp") - qinput, x_scale = input_2d, input_scale + if not self.use_aiter_and_is_supported: + # Maybe apply padding to output, see comment in __init__ + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + num_token_padding=self.output_padding, + use_per_token_if_dynamic=use_per_token_if_dynamic) else: - if not self.use_aiter_and_is_supported: - # Maybe apply padding to output, see comment in __init__ - qinput, x_scale = ops.scaled_fp8_quant( - input_2d, - input_scale, - num_token_padding=self.output_padding, - use_per_token_if_dynamic=use_per_token_if_dynamic) + if use_per_token_if_dynamic: + qinput, x_scale = ( + torch.ops.vllm.rocm_aiter_per_token_quant_fp8( + input_2d, scale=input_scale)) else: - if use_per_token_if_dynamic: - qinput, x_scale = ( - torch.ops.vllm.rocm_aiter_per_token_quant_fp8( - input_2d, scale=input_scale)) - else: - qinput, x_scale = ( - torch.ops.vllm.rocm_aiter_per_tensor_quant_fp8( - input_2d, scale=input_scale)) + qinput, x_scale = ( + torch.ops.vllm.rocm_aiter_per_tensor_quant_fp8( + input_2d, scale=input_scale)) else: qinput, x_scale = input_2d, input_scale + if envs.VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT and input_scale is not None: + qinput, x_scale = input_2d, input_scale + per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2 per_tensor_activations = (x_scale.numel() diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 7599d0fe1b77..9cbe69702103 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -300,14 +300,16 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, + input_scale: Optional[torch.Tensor] = None , + quant_hidden_states: Optional[torch.Tensor] = None ) -> torch.Tensor: if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] if is_rocm_aiter_fused_rms_quant_enabled(): - q, residual_out, y_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(q, residual=None, + q, residual_out, q_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(q, residual=None, weight=self.q_a_layernorm.weight, variance_epsilon=self.q_a_layernorm.variance_epsilon, x_scale=None, y_scale_dtype=torch.float32) - q = self.q_b_proj(q, y_scale)[0].view(-1, self.num_local_heads, self.qk_head_dim) + q = self.q_b_proj(q, q_scale)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: q = self.q_a_layernorm(q) q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, @@ -321,8 +323,14 @@ def forward( kv_a, _ = latent_cache.split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) - kv_a = 
self.kv_a_layernorm(kv_a.contiguous()) - kv = self.kv_b_proj(kv_a)[0] + if is_rocm_aiter_fused_rms_quant_enabled(): + kv_a, residual_out, kv_a_y_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(kv_a, residual=None, + weight=self.kv_a_layernorm.weight, variance_epsilon=self.kv_a_layernorm.variance_epsilon, + x_scale=None, y_scale_dtype=torch.float32) + kv = self.kv_b_proj(kv_a, kv_a_y_scale)[0] + else: + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) @@ -482,14 +490,19 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + quant_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.q_lora_rank is not None: - ckq = self.q_a_proj(hidden_states)[0] + if is_rocm_aiter_fused_rms_quant_enabled(): + ckq = self.q_a_proj(quant_hidden_states, input_scale)[0] + else: + ckq = self.q_a_proj(hidden_states)[0] hidden_states_or_q_c = self.q_a_layernorm(ckq) else: hidden_states_or_q_c = hidden_states kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) return self.mla_attn(hidden_states_or_q_c, kv_c_normed, @@ -568,16 +581,32 @@ def forward( residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) + if is_rocm_aiter_fused_rms_quant_enabled(): + q_hidden_states, residual, y_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(hidden_states, + residual=residual, weight=self.input_layernorm.weight, + variance_epsilon=self.input_layernorm.variance_epsilon, + x_scale=None, y_scale_dtype=torch.float32) + if residual is None: + residual = hidden_states + else: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + if is_rocm_aiter_fused_rms_quant_enabled(): + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_scale=y_scale, + quant_hidden_states = q_hidden_states + ) + else: + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) if hidden_states.dtype == torch.float16: # Fix FP16 overflow From afa56ed7845d459d2b8775ad71a1b0550f15354d Mon Sep 17 00:00:00 2001 From: zjli2013 Date: Tue, 29 Jul 2025 15:06:58 +0800 Subject: [PATCH 7/7] enable_cuda_graph fix --- vllm/model_executor/models/deepseek_v2.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 9cbe69702103..04f7baef0718 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -582,12 +582,18 @@ def forward( ) -> torch.Tensor: # Self Attention if is_rocm_aiter_fused_rms_quant_enabled(): - q_hidden_states, residual, y_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(hidden_states, - residual=residual, weight=self.input_layernorm.weight, + if 
residual is None: + q_hidden_states, residual_out, y_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(hidden_states, + residual=None, weight=self.input_layernorm.weight, variance_epsilon=self.input_layernorm.variance_epsilon, x_scale=None, y_scale_dtype=torch.float32) - if residual is None: residual = hidden_states + else: # tmp fix for enable cuda graph + q_hidden_states, residual_out, y_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(hidden_states, + residual=residual, weight=self.input_layernorm.weight, + variance_epsilon=self.input_layernorm.variance_epsilon, + x_scale=None, y_scale_dtype=torch.float32) + residual = residual_out else: if residual is None: residual = hidden_states
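Note (illustrative, not part of the patches): patches 6-7 lift the fusion into the decoder layer, where the input layernorm, the residual add, and the activation quantization run in one kernel, and the quantized hidden states plus their scale are handed to the attention block. A condensed sketch of that data flow, assuming the signatures above; the helper name `fused_input_layernorm` is hypothetical.

```python
# Sketch of the decoder-layer pre-norm path after patches 6-7 (illustrative only).
import torch
from vllm._aiter_ops import aiter_ops

def fused_input_layernorm(hidden_states, residual, input_layernorm):
    # Fused RMSNorm (+ residual add when a residual stream exists) + FP8 quant.
    q_hidden, residual_out, y_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
        hidden_states,
        residual=residual,                # None on the first decoder layer
        weight=input_layernorm.weight,
        variance_epsilon=input_layernorm.variance_epsilon,
        x_scale=None,
        y_scale_dtype=torch.float32)
    # First layer: the pre-norm input starts the residual stream;
    # later layers: carry forward the kernel's fused residual output.
    residual = hidden_states if residual is None else residual_out
    # q_hidden and y_scale are then passed to self_attn as quant_hidden_states
    # and input_scale, so q_a_proj consumes pre-quantized activations.
    return q_hidden, residual, y_scale
```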