diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index c084f410c573..578e32bc73f3 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -6,7 +6,6 @@
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
-
 def rocm_aiter_tuned_gemm_impl(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -45,6 +44,49 @@ def rocm_aiter_tuned_gemm_fake(
     return torch.empty((m, n), dtype=out_dtype, device=input.device)
 
 
+def rocm_aiter_rmsnorm2d_fwd_with_add_quant_impl(
+        input: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        weight: torch.Tensor,
+        variance_epsilon: float,
+        x_scale: Optional[torch.Tensor] = None,
+        y_scale_dtype: Optional[torch.dtype] = None,
+        q_dtype: torch.dtype = torch.float8_e4m3fnuz,
+        model_sensitive: float = 0,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    # TODO: make q_dtype general
+    import aiter as rocm_aiter
+
+    # TODO: only dynamic (per-token) y_scale is supported for now.
+    assert y_scale_dtype is not None
+
+    output = torch.empty(input.shape, dtype=q_dtype, device=input.device)
+    y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype,
+                          device=input.device)
+
+    if x_scale is None:
+        if residual is None:
+            residual_out = None
+            rocm_aiter.rmsnorm2d_fwd_with_dynamicquant(
+                output, input, y_scale, weight, variance_epsilon,
+                model_sensitive)
+        else:
+            residual_out = torch.empty_like(input)
+            rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant(
+                output, input, residual, residual_out, y_scale, weight,
+                variance_epsilon, model_sensitive)
+    else:
+        if residual is None:
+            residual_out = None
+            rocm_aiter.rmsnorm2d_fwd_with_smoothquant(
+                output, input, x_scale, y_scale, weight, variance_epsilon,
+                model_sensitive)
+        else:
+            residual_out = torch.empty_like(input)
+            rocm_aiter.rmsnorm2d_fwd_with_add_smoothquant(
+                output, input, residual, residual_out, x_scale, y_scale,
+                weight, variance_epsilon)
+    return output, residual_out, y_scale
+
+
+def rocm_aiter_rmsnorm2d_fwd_with_add_quant_fake(
+        input: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        weight: torch.Tensor,
+        variance_epsilon: float,
+        x_scale: Optional[torch.Tensor] = None,
+        y_scale_dtype: Optional[torch.dtype] = None,
+        q_dtype: torch.dtype = torch.float8_e4m3fnuz,
+        model_sensitive: float = 0,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    return (torch.empty(input.shape, dtype=q_dtype, device=input.device),
+            torch.empty_like(input),
+            torch.empty(input.shape[0], 1, dtype=y_scale_dtype,
+                        device=input.device))
+
+
 if current_platform.is_rocm():
     direct_register_custom_op(
         op_name="rocm_aiter_tuned_gemm",
@@ -54,6 +96,13 @@ def rocm_aiter_tuned_gemm_fake(
         dispatch_key=current_platform.dispatch_key,
     )
 
+    direct_register_custom_op(
+        op_name="rocm_aiter_rmsnorm2d_fwd_with_add_quant",
+        op_func=rocm_aiter_rmsnorm2d_fwd_with_add_quant_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
 
 
 class aiter_ops:
@@ -73,4 +122,16 @@ def rocm_aiter_tuned_gemm(
             out_dtype=out_dtype,
             scale_a=scale_a,
             scale_b=scale_b,
-        )
\ No newline at end of file
+        )
+
+    @staticmethod
+    def rocm_aiter_rmsnorm2d_fwd_with_add_quant(
+            input: torch.Tensor,
+            residual: Optional[torch.Tensor],
+            weight: torch.Tensor,
+            variance_epsilon: float,
+            x_scale: Optional[torch.Tensor] = None,
+            y_scale_dtype: Optional[torch.dtype] = None,
+            q_dtype: torch.dtype = torch.float8_e4m3fnuz,
+            model_sensitive: float = 0,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
+            input, residual, weight, variance_epsilon, x_scale,
+            y_scale_dtype, q_dtype, model_sensitive)
diff --git a/vllm/envs.py b/vllm/envs.py
index 8f0fd84f32f8..a99d3d6c62d4 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -79,6 +79,7 @@
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_ASMMOE: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
+    VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
     VLLM_ROCM_USE_AITER_ROPE: bool = True
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
@@ -589,6 +590,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_ROCM_CUSTOM_PAGED_ATTN":
     lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
              ("true", "1")),
+
+    # Whether to use the fused RMSNorm + FP8 quant kernel from aiter
+    "VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT", "True").lower() in
+             ("true", "1")),
 
     # Divisor for dynamic query scale factor calculation for FP8 KV Cache
     "Q_SCALE_CONSTANT":
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 802e0163a41d..4649818e6923 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -381,11 +381,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         param.data.copy_(loaded_weight)
 
     def forward(
-        self, x: torch.Tensor
+        self, x: torch.Tensor, input_scale: Optional[torch.Tensor] = None
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
-        output = self.quant_method.apply(self, x, bias)
+        if input_scale is not None:
+            output = self.quant_method.apply(self, x, bias, input_scale)
+        else:
+            output = self.quant_method.apply(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         if not self.return_bias:
             return output
@@ -532,13 +535,14 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
             param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(
-        self, input_
+        self, input_, input_scale: Optional[torch.Tensor] = None
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
         assert self.quant_method is not None
-        output_parallel = self.quant_method.apply(self, input_, bias)
+        # input_scale is only set on the fused rmsnorm + quant path.
+        if input_scale is not None:
+            output_parallel = self.quant_method.apply(self, input_, bias,
+                                                      input_scale)
+        else:
+            output_parallel = self.quant_method.apply(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index cb9a48d7746b..b0c91ea665b6 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -566,7 +566,8 @@ def create_weights(self, layer: torch.nn.Module,
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
-              bias: Optional[torch.Tensor] = None):
+              bias: Optional[torch.Tensor] = None,
+              input_scale: Optional[torch.Tensor] = None):
         """
         Use the output of create_weights and the CompressedTensorsScheme
         associated with the layer to apply the forward pass with the
@@ -577,7 +578,7 @@ def apply(self,
         scheme = layer.scheme
         if scheme is None:
             raise ValueError("A scheme must be defined for each layer")
-        return scheme.apply_weights(layer, x, bias=bias)
+        return scheme.apply_weights(layer, x, bias=bias,
+                                    input_scale=input_scale)
 
 
 class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
index daa25d23a306..7a09bb4cc9e0 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
@@ -32,7 +32,7 @@ def create_weights(self, *args, **kwargs):
     @abstractmethod
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]):
+                      bias: Optional[torch.Tensor],
+                      input_scale: Optional[torch.Tensor]):
         """
         Run the forward pass for the particular scheme. This is where
         scheme-specific dequant/quant steps/kernels should be applied.
@@ -41,6 +41,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
             other parameters relevant to the particular scheme.
         :param x: input to the layer
         :param bias: bias parameter
+        :param input_scale: input scale used by the fused rmsnorm + quant path
         """
         raise NotImplementedError
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 59fa4d17333b..e44556dbaa23 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -156,11 +156,14 @@ def create_weights(self, layer: torch.nn.Module,
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+                      bias: Optional[torch.Tensor] = None,
+                      input_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        # Prefer the per-token scale from the fused rmsnorm + quant kernel.
+        effective_input_scale = (input_scale if input_scale is not None
+                                 else layer.input_scale)
         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
                                      weight_scale=layer.weight_scale,
                                      out_dtype=self.out_dtype,
-                                     input_scale=layer.input_scale,
+                                     input_scale=effective_input_scale,
                                      bias=bias)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index f286fe824029..70875bd230c0 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -396,7 +396,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
-              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+              bias: Optional[torch.Tensor] = None,
+              input_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Apply the FP8 linear transformation.
+
+        Args:
+            layer: The linear layer module.
+            x: Input tensor.
+            bias: Optional bias tensor.
+            input_scale: Optional input scale tensor produced by the fused
+                rmsnorm + quant kernel.
+        """
+        # Use the provided input_scale if available, otherwise fall back to
+        # the layer's static input_scale.
+        effective_input_scale = (input_scale if input_scale is not None
+                                 else layer.input_scale)
 
         if self.use_marlin:
             return apply_fp8_marlin_linear(
@@ -415,7 +427,7 @@ def apply(self,
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
                 weight_scale=layer.weight_scale_inv,
-                input_scale=layer.input_scale,
+                input_scale=effective_input_scale,
                 bias=bias,
                 cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                 use_aiter_and_is_supported=self.use_aiter_and_is_supported,
@@ -425,7 +437,7 @@ def apply(self,
                 weight=layer.weight,
                 weight_scale=layer.weight_scale,
                 out_dtype=self.out_dtype,
-                input_scale=layer.input_scale,
+                input_scale=effective_input_scale,
                 bias=bias)
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 774943bd50c5..ac10a346e7c2 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -387,7 +387,6 @@ def apply(
         # so fallback to naive if per channel or per token
         else:
             if input.dtype != current_platform.fp8_dtype():
-
                 if not self.use_aiter_and_is_supported:
                     # Maybe apply padding to output, see comment in __init__
                     qinput, x_scale = ops.scaled_fp8_quant(
@@ -407,6 +406,9 @@ def apply(
             else:
                 qinput, x_scale = input_2d, input_scale
 
+            # If the fused rmsnorm + quant kernel already quantized the
+            # input, reuse its per-token scale directly.
+            if (envs.VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT
+                    and input_scale is not None):
+                qinput, x_scale = input_2d, input_scale
+
             per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
             per_tensor_activations = (x_scale.numel()
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 23b450aeddac..04f7baef0718 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -57,6 +57,15 @@
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
+import vllm.envs as envs
+from vllm.platforms import current_platform
+from vllm._aiter_ops import aiter_ops
+
+
+def is_rocm_aiter_fused_rms_quant_enabled() -> bool:
+    return (current_platform.is_rocm()
+            and envs.VLLM_ROCM_USE_AITER_RMSNORM
+            and envs.VLLM_ROCM_USE_AITER
+            and envs.VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT)
+
 
 class DeepseekV2MLP(nn.Module):
@@ -291,12 +300,20 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        input_scale: Optional[torch.Tensor] = None,
+        quant_hidden_states: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
-                                         self.qk_head_dim)
+            if is_rocm_aiter_fused_rms_quant_enabled():
+                q, _, q_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
+                    q,
+                    residual=None,
+                    weight=self.q_a_layernorm.weight,
+                    variance_epsilon=self.q_a_layernorm.variance_epsilon,
+                    x_scale=None,
+                    y_scale_dtype=torch.float32)
+                q = self.q_b_proj(q, q_scale)[0].view(
+                    -1, self.num_local_heads, self.qk_head_dim)
+            else:
+                q = self.q_a_layernorm(q)
+                q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
+                                             self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
                                                    self.qk_head_dim)
@@ -306,8 +323,14 @@ def forward(
         kv_a, _ = latent_cache.split(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a.contiguous())
-        kv = self.kv_b_proj(kv_a)[0]
+        if is_rocm_aiter_fused_rms_quant_enabled():
+            kv_a, _, kv_a_y_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
+                kv_a,
+                residual=None,
+                weight=self.kv_a_layernorm.weight,
+                variance_epsilon=self.kv_a_layernorm.variance_epsilon,
+                x_scale=None,
+                y_scale_dtype=torch.float32)
+            kv = self.kv_b_proj(kv_a, kv_a_y_scale)[0]
+        else:
+            kv_a = self.kv_a_layernorm(kv_a.contiguous())
+            kv = self.kv_b_proj(kv_a)[0]
         kv = kv.view(-1, self.num_local_heads,
                      self.qk_nope_head_dim + self.v_head_dim)
         k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
@@ -467,14 +490,19 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        input_scale: Optional[torch.Tensor] = None,
+        quant_hidden_states: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
-            ckq = self.q_a_proj(hidden_states)[0]
+            if is_rocm_aiter_fused_rms_quant_enabled():
+                # quant_hidden_states/input_scale come from the fused
+                # rmsnorm + quant kernel in the decoder layer.
+                ckq = self.q_a_proj(quant_hidden_states, input_scale)[0]
+            else:
+                ckq = self.q_a_proj(hidden_states)[0]
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
-            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         return self.mla_attn(hidden_states_or_q_c, kv_c_normed,
@@ -553,16 +581,38 @@ def forward(
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-        )
+        if is_rocm_aiter_fused_rms_quant_enabled():
+            if residual is None:
+                q_hidden_states, _, y_scale = (
+                    aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
+                        hidden_states,
+                        residual=None,
+                        weight=self.input_layernorm.weight,
+                        variance_epsilon=self.input_layernorm.variance_epsilon,
+                        x_scale=None,
+                        y_scale_dtype=torch.float32))
+                residual = hidden_states
+            else:  # temporary fix to keep cuda graph capture working
+                q_hidden_states, residual_out, y_scale = (
+                    aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
+                        hidden_states,
+                        residual=residual,
+                        weight=self.input_layernorm.weight,
+                        variance_epsilon=self.input_layernorm.variance_epsilon,
+                        x_scale=None,
+                        y_scale_dtype=torch.float32))
+                residual = residual_out
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
+        if is_rocm_aiter_fused_rms_quant_enabled():
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                input_scale=y_scale,
+                quant_hidden_states=q_hidden_states,
+            )
+        else:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+            )
         if hidden_states.dtype == torch.float16:
             # Fix FP16 overflow