                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
+from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.model_executor.layers.quantization.quark.quark import QuarkLinearMethod
+from vllm.model_executor.layers.quantization.quark.schemes.quark_w4a4_mxfp4 import QuarkW4A4MXFP4
 
 from vllm.platforms import current_platform
 from vllm.logger import init_logger
 logger = init_logger(__name__)
 
 if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
-    from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT
+    from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT, VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT
+    from vllm.model_executor.layers.layernorm import VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
 else:
     VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = False
+    VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT = False
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT = False
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
 
 VLLM_ROCM_USE_AITER_MHA = envs.VLLM_ROCM_USE_AITER_MHA
@@ -104,15 +110,23 @@ def __init__(
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
-        self.block_quant = hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None
+        self.block_quant = isinstance(self.down_proj.quant_method, Fp8LinearMethod) and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None
         if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and not self.block_quant:
-            logger.info("[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT will not be activated because this model is not using blocked quantization")
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT will not be activated because {self.__class__.__name__} is not using FP8 blockscale GEMM")
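+        # The fused SiLU-mul + FP4 quant path requires down_proj to use the Quark W4A4 MXFP4 (FP4 blockscale) scheme.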
+        self.fp4_block_quant_gemm = (isinstance(self.down_proj.quant_method, QuarkLinearMethod) and hasattr(self.down_proj, "scheme") and isinstance(self.down_proj.scheme, QuarkW4A4MXFP4))
+        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT and not self.fp4_block_quant_gemm:
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT will not be activated because {self.__class__.__name__} is not using FP4 blockscale GEMM")
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        x, _ = self.gate_up_proj(x)
-        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and self.block_quant:
-            x = torch.ops.vllm.act_mul_and_fp8_group_quant(x)
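+        # The decoder layer may hand in pre-quantized activations as a (tensor, scales) tuple; unpack before the GEMM.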
+        x_quant_scales = None
+        if isinstance(x, tuple):
+            x, x_quant_scales = x
+        x, _ = self.gate_up_proj(x, x_quant_scales=x_quant_scales)
+        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT and self.fp4_block_quant_gemm:
+            x = torch.ops.vllm.rocm_aiter_act_mul_and_fp4_group_quant(x)
+        elif VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and self.block_quant:
+            x = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant(x)
         else:
             x = self.act_fn(x)
         x_quant_scales = None
@@ -220,7 +234,11 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
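+        # hidden_states may arrive pre-quantized from the fused RMSNorm + FP4 quant kernel as a (tensor, scales) tuple.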
+        hidden_states_quant = None
+        if isinstance(hidden_states, tuple):
+            hidden_states, hidden_states_quant = hidden_states
+
+        qkv, _ = self.qkv_proj(hidden_states, x_quant_scales=hidden_states_quant)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
         if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
@@ -316,6 +334,12 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
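+        # The fused RMSNorm + FP4 quant path requires the projection consuming the norm output (qkv_proj / gate_up_proj) to be a Quark W4A4 MXFP4 GEMM.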
+        self.input_layernorm_with_fp4_block_quant_gemm = (isinstance(self.self_attn.qkv_proj.quant_method, QuarkLinearMethod) and hasattr(self.self_attn.qkv_proj, "scheme") and isinstance(self.self_attn.qkv_proj.scheme, QuarkW4A4MXFP4))
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and not self.input_layernorm_with_fp4_block_quant_gemm:
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT will not be activated because {self.self_attn.__class__.__name__} is not using FP4 blockscale GEMM")
+        self.post_attention_layernorm_with_fp4_block_quant_gemm = (isinstance(self.mlp.gate_up_proj.quant_method, QuarkLinearMethod) and hasattr(self.mlp.gate_up_proj, "scheme") and isinstance(self.mlp.gate_up_proj.scheme, QuarkW4A4MXFP4))
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and not self.post_attention_layernorm_with_fp4_block_quant_gemm:
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT will not be activated because {self.mlp.__class__.__name__} is not using FP4 blockscale GEMM")
 
     def forward(
         self,
@@ -324,18 +348,34 @@ def forward(
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
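+        # The fused kernel returns (quantized hidden states, group scales, updated residual) in a single pass.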
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and self.input_layernorm_with_fp4_block_quant_gemm:
+            weight = self.input_layernorm.weight
+            eps = self.input_layernorm.variance_epsilon
+            if residual is None:
+                residual = hidden_states
+                hidden_states_quant, hidden_states_quant_scales, _ = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, None)
+            else:
+                hidden_states_quant, hidden_states_quant_scales, residual = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, residual)
+            hidden_states = (hidden_states_quant, hidden_states_quant_scales)
         else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
         hidden_states = self.self_attn(positions=positions,
-                                       hidden_states=hidden_states)
+                                       hidden_states=hidden_states)
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
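+        # Same fused RMSNorm + FP4 quant for the pre-MLP norm; MLP.forward unpacks the (tensor, scales) tuple.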
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and self.post_attention_layernorm_with_fp4_block_quant_gemm:
+            weight = self.post_attention_layernorm.weight
+            eps = self.post_attention_layernorm.variance_epsilon
+            hidden_states_quant, hidden_states_quant_scales, residual = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, residual)
+            hidden_states = (hidden_states_quant, hidden_states_quant_scales)
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 