
Commit 21f592a

re-enable VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT functionality for LL 70B FP4
add fused_rms_fp4_quant shuffle support; add shape check for Triton
1 parent 3864a30 commit 21f592a

File tree

6 files changed: +186 −42 lines

vllm/envs.py

Lines changed: 3 additions & 0 deletions
@@ -173,6 +173,7 @@
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
     VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT: bool = True
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT: bool = True
     VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD: bool = True
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: bool = True
     VLLM_ROCM_USE_AITER_TRITON_FP8_BMM: bool = True
@@ -1241,6 +1242,8 @@ def get_vllm_port() -> Optional[int]:
     # Use AITER Triton fused RMSNORM + Quantization
     "VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT":
     lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT", "1"))),
+    "VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT":
+    lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT", "1"))),
 
     # Use AITER Triton fused elementwise multiply + elementwise addtion
     "VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD":

vllm/model_executor/layers/activation.py

Lines changed: 43 additions & 20 deletions
@@ -26,22 +26,52 @@
 
 if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT:
     from aiter.ops.triton.activation import act_mul_and_mxfp4_quant
+    rocm_aiter_fp4_quant_group_size = 32
+
+    def rocm_aiter_act_mul_and_fp4_group_quant_impl(
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        M = x.shape[0]
+        shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM and (M >= 32)
+        x_fp4, out_bs = act_mul_and_mxfp4_quant(x, activation="silu", shuffle=shuffle, scale_shuffle_padding=True)
+        return x_fp4, out_bs
+
+    def rocm_aiter_act_mul_and_fp4_group_quant_fake(
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        M, N = x.shape
+        assert N % 4 == 0
+        N_half = N // 2
+        x_fp4 = torch.empty((M, N_half // 2), dtype=torch.uint8, device=x.device)
+        scaleN_valid = (N_half + rocm_aiter_fp4_quant_group_size - 1) // rocm_aiter_fp4_quant_group_size
+        scaleM = (M + 255) // 256 * 256
+        scaleN = (scaleN_valid + 7) // 8 * 8
+        out_bs = torch.empty((scaleM, scaleN), dtype=torch.uint8, device=x.device)
+        return x_fp4, out_bs
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_act_mul_and_fp4_group_quant",
+        op_func=rocm_aiter_act_mul_and_fp4_group_quant_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_act_mul_and_fp4_group_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
 
 VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = envs.VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT
 
 if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT:
-    logger.info("[Aiter] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT=1")
     from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
     import aiter as rocm_aiter
     rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
     rocm_aiter_fp8_quant_group_size = 128
 
-    def act_mul_and_fp8_group_quant_impl(
+    def rocm_aiter_act_mul_and_fp8_group_quant_impl(
         x: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        return act_mul_and_fp8_group_quant(x, activation="silu", group_size=rocm_aiter_fp8_quant_group_size, dtype_quant=rocm_aiter_fp8_dtype)
+        return rocm_aiter_act_mul_and_fp8_group_quant(x, activation="silu", group_size=rocm_aiter_fp8_quant_group_size, dtype_quant=rocm_aiter_fp8_dtype)
 
-    def act_mul_and_fp8_group_quant_fake(
+    def rocm_aiter_act_mul_and_fp8_group_quant_fake(
         x: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         M, N = x.shape
@@ -52,10 +82,10 @@ def act_mul_and_fp8_group_quant_fake(
         return x_fp8, out_bs
 
     direct_register_custom_op(
-        op_name="act_mul_and_fp8_group_quant",
-        op_func=act_mul_and_fp8_group_quant_impl,
+        op_name="rocm_aiter_act_mul_and_fp8_group_quant",
+        op_func=rocm_aiter_act_mul_and_fp8_group_quant_impl,
         mutates_args=[],
-        fake_impl=act_mul_and_fp8_group_quant_fake,
+        fake_impl=rocm_aiter_act_mul_and_fp8_group_quant_fake,
         dispatch_key=current_platform.dispatch_key,
     )
 
@@ -118,9 +148,7 @@ class SiluAndMul(CustomOp):
     def __init__(self):
         super().__init__()
 
-        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT:
-            self.op = lambda x, shuffle: act_mul_and_mxfp4_quant(x, "silu", shuffle=shuffle)
-        elif current_platform.is_cuda_alike():
+        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.silu_and_mul
         elif current_platform.is_xpu():
             from vllm._ipex_ops import ipex_ops
@@ -136,16 +164,11 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
     def forward_cuda(self,
                      x: torch.Tensor,
                      scale: Optional[torch.Tensor] = None) -> torch.Tensor:
-        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT:
-            shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM and x.shape[0] >= 32
-            out, out_scales = self.op(x, shuffle)
-            return out, out_scales
-        else:
-            d = x.shape[-1] // 2
-            output_shape = (x.shape[:-1] + (d, ))
-            out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-            self.op(out, x)
-            return out
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        self.op(out, x)
+        return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
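As a reading aid (not part of the commit), the fake implementation's shape arithmetic can be checked in isolation; the helper below and the sample sizes are illustrative only:

def fp4_quant_output_shapes(M: int, N: int, group_size: int = 32):
    # silu_and_mul halves the feature dimension, then FP4 packs two values per byte
    N_half = N // 2
    x_fp4_shape = (M, N_half // 2)
    # one block scale per group of 32 elements, padded for the shuffled layout:
    # rows to a multiple of 256, columns to a multiple of 8
    scaleN_valid = (N_half + group_size - 1) // group_size
    scaleM = (M + 255) // 256 * 256
    scaleN = (scaleN_valid + 7) // 8 * 8
    return x_fp4_shape, (scaleM, scaleN)

# Example: a (64, 28672) gate_up activation
print(fp4_quant_output_shapes(64, 28672))  # ((64, 7168), (256, 448))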

vllm/model_executor/layers/layernorm.py

Lines changed: 65 additions & 0 deletions
@@ -11,6 +11,71 @@
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
+from vllm.logger import init_logger
+logger = init_logger(__name__)
+
+if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
+    VLLM_TRITON_FP4_GEMM_USE_ASM = envs.VLLM_ROCM_USE_AITER and envs.VLLM_TRITON_FP4_GEMM_USE_ASM
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT = envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT
+
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT:
+        from aiter import per_1x32_f4_quant_hip
+        from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant
+        rocm_aiter_fp4_quant_group_size = 32
+
+        def rocm_aiter_fused_rms_and_fp4_group_quant_impl(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            eps: float,
+            residual: Optional[torch.Tensor] = None,
+        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            M = x.shape[0]
+            res = None
+            if M <= 64:
+                shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM and (M >= 32)
+                (x_fp4, out_bs), _, res = fused_rms_mxfp4_quant(x, weight, eps,
+                                                                None, None, eps,
+                                                                res1=residual,
+                                                                shuffle=shuffle,
+                                                                scale_shuffle_padding=True)
+            else:
+                shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM
+                if residual is not None:
+                    x_rms, res = torch.ops.vllm.rocm_aiter_fused_add_rms_norm(x, residual, weight, eps)
+                else:
+                    x_rms = torch.ops.vllm.rocm_aiter_rms_norm(x, weight, eps)
+                x_fp4, out_bs = per_1x32_f4_quant_hip(x_rms, shuffle=shuffle)
+
+            if res is None:
+                return x_fp4, out_bs, x
+            return x_fp4, out_bs, res
+
+        def rocm_aiter_fused_rms_and_fp4_group_quant_fake(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            eps: float,
+            residual: Optional[torch.Tensor] = None,
+        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            M, N = x.shape
+            x_fp4 = torch.empty((M, N // 2), dtype=torch.uint8, device=x.device)
+            scaleN_valid = (N + rocm_aiter_fp4_quant_group_size - 1) // rocm_aiter_fp4_quant_group_size
+            scaleM = (M + 255) // 256 * 256
+            scaleN = (scaleN_valid + 7) // 8 * 8
+            out_bs = torch.empty((scaleM, scaleN), dtype=torch.uint8, device=x.device)
+            res = torch.empty((M, N), dtype=x.dtype, device=x.device)
+            return x_fp4, out_bs, res
+
+        direct_register_custom_op(
+            op_name="rocm_aiter_fused_rms_and_fp4_group_quant",
+            op_func=rocm_aiter_fused_rms_and_fp4_group_quant_impl,
+            mutates_args=[],
+            fake_impl=rocm_aiter_fused_rms_and_fp4_group_quant_fake,
+            dispatch_key=current_platform.dispatch_key,
+        )
+else:
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT = False
+
+logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT=}")
 
 def is_rocm_aiter_rmsnorm_enabled() -> bool:
     return current_platform.is_rocm() \
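For orientation (a sketch, not part of the diff): once registered, the fused op is reachable through torch.ops.vllm and returns the packed FP4 activation, its block scales, and the updated residual; this is the triple that llama.py consumes below. The shapes are illustrative and assume a ROCm build with AITER available and the op registered as above.

import torch

hidden = torch.randn(64, 8192, dtype=torch.bfloat16, device="cuda")
weight = torch.ones(8192, dtype=torch.bfloat16, device="cuda")
residual = torch.zeros_like(hidden)

x_fp4, x_scales, residual = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(
    hidden, weight, 1e-6, residual)
# x_fp4:    (64, 4096) uint8, two FP4 values packed per byte
# x_scales: (256, 256) block scales, padded for the shuffled GEMM layout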

vllm/model_executor/layers/linear.py

Lines changed: 2 additions & 2 deletions
@@ -597,8 +597,8 @@ def forward(
         # Matrix multiply.
         assert self.quant_method is not None
         from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
-        from vllm.model_executor.layers.quantization.quark.schemes.quark_w4a4_mxfp4 import QuarkW4A4MXFP4
-        if isinstance(self.quant_method, Fp8LinearMethod) or isinstance(self.quant_method, QuarkW4A4MXFP4):
+        from vllm.model_executor.layers.quantization.quark.quark import QuarkLinearMethod
+        if isinstance(self.quant_method, Fp8LinearMethod) or isinstance(self.quant_method, QuarkLinearMethod):
             output_parallel = self.quant_method.apply(self, input_, bias, x_quant_scales=x_quant_scales)
         else:
             assert x_quant_scales is None, f"x_quant_scales input is not supported for {self.quant_method.__class__}"
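The reason for the changed isinstance target, sketched below (simplified, not the vLLM source): a layer's quant_method is the QuarkLinearMethod wrapper, while QuarkW4A4MXFP4 is the per-layer scheme (the attribute llama.py checks via layer.scheme), so checking quant_method against the scheme class would not match for Quark-quantized layers and pre-quantized inputs were rejected.

# Illustrative helpers only; the attribute layout is an assumption drawn from the diff.
def accepts_prequantized_input(layer) -> bool:
    # the wrapper method decides whether apply() accepts x_quant_scales
    return isinstance(layer.quant_method, (Fp8LinearMethod, QuarkLinearMethod))

def uses_mxfp4_scheme(layer) -> bool:
    # the MXFP4-specific details live in the layer's scheme
    return isinstance(getattr(layer, "scheme", None), QuarkW4A4MXFP4)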

vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

Lines changed: 18 additions & 5 deletions
@@ -25,8 +25,10 @@
 
 def aiter_triton_gemm_check(m, n, k):
     if m <= 64:
-        return ((n == 8192 and k == 8192) or (n == 10240 and k == 8192)
-                or (n == 57344 and k == 8192) or (n == 8192 and k == 28672))
+        return (
+            (n == 10240 and k == 8192) or (n == 8192 and k == 8192) or (n == 57344 and k == 8192) or (n == 8192 and k == 28672) or
+            (n == 1280 and k == 8192) or (n == 8192 and k == 1024) or (n == 7168 and k == 8192) or (n == 8192 and k == 3584)
+        )
     return False
 
 def gemm_with_dynamic_quant(
@@ -67,22 +69,33 @@ def gemm_with_dynamic_quant(
 
         gemm_afp4wfp4_preshuffled_weight_scales(x_q.view(torch.uint8), weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
                                                 x_s, weight_scale.view(torch.uint8).view(weight_scale.shape[0] // 32, -1), out_dtype, y)
+
     else:
 
         if x_scales is None:
             # use hip quant kernel for performance
             x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
         else:
-            x_q = x
-            x_s = x_scales
+            x_q = x.view(torch.float4_e2m1fn_x2)
+            x_s = x_scales.view(torch.float8_e8m0fnu)
 
         # 32 alignment is enough for dim0 padding of output for
         # gemm_a4w4 kernel
         y = torch.empty((M + 31) // 32 * 32,
                         weight.shape[0],
                         device=x_q.device,
                         dtype=out_dtype)
-
+
+        # weight = weight.view(x_q.dtype)
+        # weight_scale = weight_scale.view(x_s.dtype)
+        # print("fp4dtype", x_q.dtype, weight.dtype, x_s.dtype, weight_scale.dtype)
+
+        # gemm_a4w4(x_q,
+        #           weight,
+        #           x_s,
+        #           weight_scale,
+        #           y,
+        #           bpreshuffle=True)
         gemm_a4w4(x_q,
                   weight.view(x_q.dtype),
                   x_s,
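One detail worth noting (explanatory sketch, not part of the diff): when the activation arrives pre-quantized from the fused RMSNorm/SiLU-mul ops, it is carried as raw uint8 storage, so it must be reinterpreted as the packed-FP4 and e8m0 dtypes that gemm_a4w4 expects. The shapes below are illustrative and assume a PyTorch build that defines these dtypes, as the diff does.

import torch

packed = torch.empty(64, 4096, dtype=torch.uint8, device="cuda")  # two FP4 values per byte
scales = torch.empty(256, 256, dtype=torch.uint8, device="cuda")  # block scales

x_q = packed.view(torch.float4_e2m1fn_x2)  # same storage, packed-FP4 element type
x_s = scales.view(torch.float8_e8m0fnu)    # same storage, e8m0 element type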

vllm/model_executor/models/llama.py

Lines changed: 55 additions & 15 deletions
@@ -56,16 +56,22 @@
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
+from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.model_executor.layers.quantization.quark.quark import QuarkLinearMethod
+from vllm.model_executor.layers.quantization.quark.schemes.quark_w4a4_mxfp4 import QuarkW4A4MXFP4
 
 from vllm.platforms import current_platform
 from vllm.logger import init_logger
 logger = init_logger(__name__)
 
 if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
-    from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT
+    from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT, VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT
+    from vllm.model_executor.layers.layernorm import VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
 else:
     VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = False
+    VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT = False
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT = False
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
 
 VLLM_ROCM_USE_AITER_MHA = envs.VLLM_ROCM_USE_AITER_MHA
@@ -104,15 +110,23 @@ def __init__(
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
-        self.block_quant = hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None
+        self.block_quant = isinstance(self.down_proj.quant_method, Fp8LinearMethod) and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None
         if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and not self.block_quant:
-            logger.info("[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT will not be activated because this model is not using blocked quantization")
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT will not be activated because {self.__class__.__name__} is not using FP8 blockscale GEMM")
+        self.fp4_block_quant_gemm = (isinstance(self.down_proj.quant_method, QuarkLinearMethod) and hasattr(self.down_proj, "scheme") and isinstance(self.down_proj.scheme, QuarkW4A4MXFP4))
+        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT and not self.fp4_block_quant_gemm:
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT will not be activated because {self.__class__.__name__} is not using FP4 blockscale GEMM")
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        x, _ = self.gate_up_proj(x)
-        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and self.block_quant:
-            x = torch.ops.vllm.act_mul_and_fp8_group_quant(x)
+        x_quant_scales = None
+        if isinstance(x, tuple):
+            x, x_quant_scales = x
+        x, _ = self.gate_up_proj(x, x_quant_scales=x_quant_scales)
+        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT and self.fp4_block_quant_gemm:
+            x = torch.ops.vllm.rocm_aiter_act_mul_and_fp4_group_quant(x)
+        elif VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and self.block_quant:
+            x = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant(x)
         else:
             x = self.act_fn(x)
             x_quant_scales = None
@@ -220,7 +234,11 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
+        hidden_states_quant = None
+        if isinstance(hidden_states, tuple):
+            hidden_states, hidden_states_quant = hidden_states
+
+        qkv, _ = self.qkv_proj(hidden_states, x_quant_scales = hidden_states_quant)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
         if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
@@ -316,6 +334,12 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.input_layernorm_with_fp4_block_quant_gemm = (isinstance(self.self_attn.qkv_proj.quant_method, QuarkLinearMethod) and hasattr(self.self_attn.qkv_proj, "scheme") and isinstance(self.self_attn.qkv_proj.scheme, QuarkW4A4MXFP4))
+        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT and not self.input_layernorm_with_fp4_block_quant_gemm:
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT will not be activated because {self.self_attn.__class__.__name__} is not using FP4 blockscale GEMM")
+        self.post_attention_layernorm_with_fp4_block_quant_gemm = (isinstance(self.mlp.gate_up_proj.quant_method, QuarkLinearMethod) and hasattr(self.mlp.gate_up_proj, "scheme") and isinstance(self.mlp.gate_up_proj.scheme, QuarkW4A4MXFP4))
+        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT and not self.post_attention_layernorm_with_fp4_block_quant_gemm:
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT will not be activated because {self.mlp.__class__.__name__} is not using FP4 blockscale GEMM")
 
     def forward(
         self,
@@ -324,18 +348,34 @@ def forward(
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and self.input_layernorm_with_fp4_block_quant_gemm:
+            weight = self.input_layernorm.weight
+            eps = self.input_layernorm.variance_epsilon
+            if residual is None:
+                residual = hidden_states
+                hidden_states_quant, hidden_states_quant_scales, _ = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, None)
+            else:
+                hidden_states_quant, hidden_states_quant_scales, residual = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, residual)
+            hidden_states = (hidden_states_quant, hidden_states_quant_scales)
         else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
         hidden_states = self.self_attn(positions=positions,
-                                       hidden_states=hidden_states)
+                                        hidden_states=hidden_states)
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and self.post_attention_layernorm_with_fp4_block_quant_gemm:
+            weight = self.post_attention_layernorm.weight
+            eps = self.post_attention_layernorm.variance_epsilon
+            hidden_states_quant, hidden_states_quant_scales, residual = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, residual)
+            hidden_states = (hidden_states_quant, hidden_states_quant_scales)
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
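To summarize the new data flow (a simplified sketch, not the vLLM source): when the fused RMSNorm + FP4 quant path is taken, hidden_states is carried as a (packed_fp4, block_scales) tuple, and each consumer unpacks it and forwards the scales to its quantized linear layer. The helper name below is hypothetical.

def consume_hidden_states(hidden_states, proj):
    x_quant_scales = None
    if isinstance(hidden_states, tuple):  # output of the fused RMSNorm + FP4 quant op
        hidden_states, x_quant_scales = hidden_states
    out, _ = proj(hidden_states, x_quant_scales=x_quant_scales)
    return out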
