3 changes: 3 additions & 0 deletions vllm/envs.py
@@ -173,6 +173,7 @@
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT: bool = True
VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT: bool = True
VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD: bool = True
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: bool = True
VLLM_ROCM_USE_AITER_TRITON_FP8_BMM: bool = True
@@ -1241,11 +1242,13 @@
# Use AITER Triton fused RMSNORM + Quantization
"VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT", "1"))),
"VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT":

Check failure, GitHub Actions / pre-commit, Ruff (E501): vllm/envs.py:1245:81: Line too long (92 > 80)
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT", "1"))),

Check failure, GitHub Actions / pre-commit, Ruff (E501): vllm/envs.py:1247:81: Line too long (92 > 80)
# Use AITER Triton fused elementwise multiply + elementwise addition
"VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD", "1"))),

Check failure, GitHub Actions / pre-commit, Ruff (E501): vllm/envs.py:1251:81: Line too long (82 > 80)
# Use AITER Triton fused rope + zeros + reshape_and_cache
"VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "1"))),
63 changes: 43 additions & 20 deletions vllm/model_executor/layers/activation.py
@@ -26,22 +26,52 @@

if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT:
from aiter.ops.triton.activation import act_mul_and_mxfp4_quant
rocm_aiter_fp4_quant_group_size = 32

def rocm_aiter_act_mul_and_fp4_group_quant_impl(
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
M = x.shape[0]
shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM and (M >= 32)
x_fp4, out_bs = act_mul_and_mxfp4_quant(x, activation="silu", shuffle=shuffle, scale_shuffle_padding=True)
return x_fp4, out_bs

def rocm_aiter_act_mul_and_fp4_group_quant_fake(
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
assert N % 4 == 0
N_half = N // 2
x_fp4 = torch.empty((M, N_half // 2), dtype=torch.uint8, device=x.device)
scaleN_valid = (N_half + rocm_aiter_fp4_quant_group_size - 1) // rocm_aiter_fp4_quant_group_size
scaleM = (M + 255) // 256 * 256
scaleN = (scaleN_valid + 7) // 8 * 8
out_bs = torch.empty((scaleM, scaleN), dtype=torch.uint8, device=x.device)
return x_fp4, out_bs

direct_register_custom_op(
op_name="rocm_aiter_act_mul_and_fp4_group_quant",
op_func=rocm_aiter_act_mul_and_fp4_group_quant_impl,
mutates_args=[],
fake_impl=rocm_aiter_act_mul_and_fp4_group_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
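Once registered, the fused SiLU-mul + MXFP4 quantization kernel is reachable through the torch.ops.vllm namespace; the llama.py changes below call it this way. A hedged usage sketch, with shapes taken from the fake implementation above:

# Illustrative only. x is the gate_up projection output of shape [M, N] (gate and up
# halves concatenated), so the activated half has N_half = N // 2 columns.
x_fp4, block_scales = torch.ops.vllm.rocm_aiter_act_mul_and_fp4_group_quant(x)
# x_fp4:        [M, N_half // 2] uint8, two FP4 values packed per byte
# block_scales: [ceil(M / 256) * 256, ceil(ceil(N_half / 32) / 8) * 8] uint8,
#               padded to the shuffled-scale layout expected by the FP4 GEMM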


VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = envs.VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT

if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT:
logger.info("[Aiter] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT=1")
from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
import aiter as rocm_aiter
rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
rocm_aiter_fp8_quant_group_size = 128

def act_mul_and_fp8_group_quant_impl(
def rocm_aiter_act_mul_and_fp8_group_quant_impl(
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return act_mul_and_fp8_group_quant(x, activation="silu", group_size=rocm_aiter_fp8_quant_group_size, dtype_quant=rocm_aiter_fp8_dtype)
return act_mul_and_fp8_group_quant(x, activation="silu", group_size=rocm_aiter_fp8_quant_group_size, dtype_quant=rocm_aiter_fp8_dtype)

def act_mul_and_fp8_group_quant_fake(
def rocm_aiter_act_mul_and_fp8_group_quant_fake(
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
@@ -52,10 +82,10 @@ def act_mul_and_fp8_group_quant_fake(
return x_fp8, out_bs

direct_register_custom_op(
op_name="act_mul_and_fp8_group_quant",
op_func=act_mul_and_fp8_group_quant_impl,
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
op_func=rocm_aiter_act_mul_and_fp8_group_quant_impl,
mutates_args=[],
fake_impl=act_mul_and_fp8_group_quant_fake,
fake_impl=rocm_aiter_act_mul_and_fp8_group_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
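Because the op name now carries the rocm_aiter_ prefix, call sites must use the new name; the LlamaMLP change below does this. A hedged sketch of the renamed call:

# Illustrative only: the FP8 variant keeps its previous behaviour, returning the
# quantized activation plus per-128-column group scales.
x_fp8, group_scales = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant(gate_up_out)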

@@ -118,9 +148,7 @@ class SiluAndMul(CustomOp):
def __init__(self):
super().__init__()

if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT:
self.op = lambda x, shuffle: act_mul_and_mxfp4_quant(x, "silu", shuffle=shuffle)
elif current_platform.is_cuda_alike():
if current_platform.is_cuda_alike():
self.op = torch.ops._C.silu_and_mul
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops
@@ -136,16 +164,11 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self,
x: torch.Tensor,
scale: Optional[torch.Tensor] = None) -> torch.Tensor:
if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT:
shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM and x.shape[0] >= 32
out, out_scales = self.op(x, shuffle)
return out, out_scales
else:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x)
return out
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
65 changes: 65 additions & 0 deletions vllm/model_executor/layers/layernorm.py
@@ -11,6 +11,71 @@
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op

from vllm.logger import init_logger
logger = init_logger(__name__)

if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
VLLM_TRITON_FP4_GEMM_USE_ASM = envs.VLLM_ROCM_USE_AITER and envs.VLLM_TRITON_FP4_GEMM_USE_ASM
VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT = envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT

if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT:
from aiter import per_1x32_f4_quant_hip
from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant
rocm_aiter_fp4_quant_group_size = 32

def rocm_aiter_fused_rms_and_fp4_group_quant_impl(
x: torch.Tensor,
weight: torch.Tensor,
eps: float,
residual: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
M = x.shape[0]
res = None
if M <= 64:
shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM and (M >= 32)
(x_fp4, out_bs), _, res = fused_rms_mxfp4_quant(x, weight, eps,
None, None, eps,
res1=residual,
shuffle=shuffle,
scale_shuffle_padding=True)
else:
shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM
if residual is not None:
x_rms, res = torch.ops.vllm.rocm_aiter_fused_add_rms_norm(x, residual, weight, eps)
else:
x_rms = torch.ops.vllm.rocm_aiter_rms_norm(x, weight, eps)
x_fp4, out_bs = per_1x32_f4_quant_hip(x_rms, shuffle=shuffle)

if res is None:
return x_fp4, out_bs, x
return x_fp4, out_bs, res

def rocm_aiter_fused_rms_and_fp4_group_quant_fake(
x: torch.Tensor,
weight: torch.Tensor,
eps: float,
residual: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
M, N = x.shape
x_fp4 = torch.empty((M, N // 2), dtype=torch.uint8, device=x.device)
scaleN_valid = (N + rocm_aiter_fp4_quant_group_size - 1) // rocm_aiter_fp4_quant_group_size
scaleM = (M + 255) // 256 * 256
scaleN = (scaleN_valid + 7) // 8 * 8
out_bs = torch.empty((scaleM, scaleN), dtype=torch.uint8, device=x.device)
res = torch.empty((M, N), dtype=x.dtype, device=x.device)
return x_fp4, out_bs, res

direct_register_custom_op(
op_name="rocm_aiter_fused_rms_and_fp4_group_quant",
op_func=rocm_aiter_fused_rms_and_fp4_group_quant_impl,
mutates_args=[],
fake_impl=rocm_aiter_fused_rms_and_fp4_group_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT = False

logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT=}")

def is_rocm_aiter_rmsnorm_enabled() -> bool:
return current_platform.is_rocm() \
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/linear.py
@@ -597,8 +597,8 @@ def forward(
# Matrix multiply.
assert self.quant_method is not None
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.quark.schemes.quark_w4a4_mxfp4 import QuarkW4A4MXFP4
if isinstance(self.quant_method, Fp8LinearMethod) or isinstance(self.quant_method, QuarkW4A4MXFP4):
from vllm.model_executor.layers.quantization.quark.quark import QuarkLinearMethod
if isinstance(self.quant_method, Fp8LinearMethod) or isinstance(self.quant_method, QuarkLinearMethod):
output_parallel = self.quant_method.apply(self, input_, bias, x_quant_scales=x_quant_scales)
else:
assert x_quant_scales is None, f"x_quant_scales input is not supported for {self.quant_method.__class__}"
@@ -25,8 +25,10 @@

def aiter_triton_gemm_check(m, n, k):
if m <= 64:
return ((n == 8192 and k == 8192) or (n == 10240 and k == 8192)
or (n == 57344 and k == 8192) or (n == 8192 and k == 28672))
return (
(n == 10240 and k == 8192) or (n == 8192 and k == 8192)
or (n == 57344 and k == 8192) or (n == 8192 and k == 28672)
or (n == 1280 and k == 8192) or (n == 8192 and k == 1024)
or (n == 7168 and k == 8192) or (n == 8192 and k == 3584)
)
return False
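The whitelist appears to cover Llama-70B-class projection shapes at common tensor-parallel sizes; any other (n, k), or any m above 64, falls back to the default dynamic-quant path. A hedged example of how the check behaves:

# Shape gate above (m is the token/batch dimension):
assert aiter_triton_gemm_check(16, 8192, 28672)       # small m, whitelisted (n, k)
assert not aiter_triton_gemm_check(128, 8192, 28672)  # m > 64 always falls through
assert not aiter_triton_gemm_check(16, 4096, 4096)    # (n, k) not in the whitelist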

def gemm_with_dynamic_quant(
@@ -67,22 +69,33 @@ def gemm_with_dynamic_quant(

gemm_afp4wfp4_preshuffled_weight_scales(x_q.view(torch.uint8), weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
x_s, weight_scale.view(torch.uint8).view(weight_scale.shape[0] // 32, -1), out_dtype, y)

else:

if x_scales is None:
# use hip quant kernel for performance
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
else:
x_q = x
x_s = x_scales
x_q = x.view(torch.float4_e2m1fn_x2)
x_s = x_scales.view(torch.float8_e8m0fnu)

# 32 alignment is enough for dim0 padding of output for
# gemm_a4w4 kernel
y = torch.empty((M + 31) // 32 * 32,
weight.shape[0],
device=x_q.device,
dtype=out_dtype)

gemm_a4w4(x_q,
weight.view(x_q.dtype),
x_s,
70 changes: 55 additions & 15 deletions vllm/model_executor/models/llama.py
@@ -56,16 +56,22 @@
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.quark.quark import QuarkLinearMethod
from vllm.model_executor.layers.quantization.quark.schemes.quark_w4a4_mxfp4 import QuarkW4A4MXFP4

from vllm.platforms import current_platform
from vllm.logger import init_logger
logger = init_logger(__name__)

if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT
from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT, VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT
from vllm.model_executor.layers.layernorm import VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
else:
VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = False
VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False

VLLM_ROCM_USE_AITER_MHA = envs.VLLM_ROCM_USE_AITER_MHA
@@ -104,15 +110,23 @@ def __init__(
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.block_quant = hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None
self.block_quant = isinstance(self.down_proj.quant_method, Fp8LinearMethod) and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None
if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and not self.block_quant:
logger.info("[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT will not be activated because this model is not using blocked quantization")
logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT will not be activated because {self.__class__.__name__} is not using FP8 blockscale GEMM")
self.fp4_block_quant_gemm = (isinstance(self.down_proj.quant_method, QuarkLinearMethod) and hasattr(self.down_proj, "scheme") and isinstance(self.down_proj.scheme, QuarkW4A4MXFP4))
if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT and not self.fp4_block_quant_gemm:
logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT will not be activated because {self.__class__.__name__} is not using FP4 blockscale GEMM")
self.act_fn = SiluAndMul()

def forward(self, x):
x, _ = self.gate_up_proj(x)
if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and self.block_quant:
x = torch.ops.vllm.act_mul_and_fp8_group_quant(x)
x_quant_scales = None
if isinstance(x, tuple):
x, x_quant_scales = x
x, _ = self.gate_up_proj(x, x_quant_scales=x_quant_scales)
if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT and self.fp4_block_quant_gemm:
x = torch.ops.vllm.rocm_aiter_act_mul_and_fp4_group_quant(x)
elif VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and self.block_quant:
x = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant(x)
else:
x = self.act_fn(x)
x_quant_scales = None
@@ -220,7 +234,11 @@ def forward(
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
hidden_states_quant = None
if isinstance(hidden_states, tuple):
hidden_states, hidden_states_quant = hidden_states

qkv, _ = self.qkv_proj(hidden_states, x_quant_scales=hidden_states_quant)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
@@ -316,6 +334,12 @@ def __init__(
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.input_layernorm_with_fp4_block_quant_gemm = (isinstance(self.self_attn.qkv_proj.quant_method, QuarkLinearMethod) and hasattr(self.self_attn.qkv_proj, "scheme") and isinstance(self.self_attn.qkv_proj.scheme, QuarkW4A4MXFP4))
if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and not self.input_layernorm_with_fp4_block_quant_gemm:
logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT will not be activated because {self.self_attn.__class__.__name__} is not using FP4 blockscale GEMM")
self.post_attention_layernorm_with_fp4_block_quant_gemm = (isinstance(self.mlp.gate_up_proj.quant_method, QuarkLinearMethod) and hasattr(self.mlp.gate_up_proj, "scheme") and isinstance(self.mlp.gate_up_proj.scheme, QuarkW4A4MXFP4))
if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and not self.post_attention_layernorm_with_fp4_block_quant_gemm:
logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT will not be activated because {self.mlp.__class__.__name__} is not using FP4 blockscale GEMM")

def forward(
self,
@@ -324,18 +348,34 @@ def forward(
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and self.input_layernorm_with_fp4_block_quant_gemm:
weight = self.input_layernorm.weight
eps = self.input_layernorm.variance_epsilon
if residual is None:
residual = hidden_states
hidden_states_quant, hidden_states_quant_scales, _ = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, None)
else:
hidden_states_quant, hidden_states_quant_scales, residual = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, residual)
hidden_states = (hidden_states_quant, hidden_states_quant_scales)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)
hidden_states=hidden_states)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and self.post_attention_layernorm_with_fp4_block_quant_gemm:
weight = self.post_attention_layernorm.weight
eps = self.post_attention_layernorm.variance_epsilon
hidden_states_quant, hidden_states_quant_scales, residual = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, residual)
hidden_states = (hidden_states_quant, hidden_states_quant_scales)
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
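Taken together, the decoder-layer changes thread a (packed FP4, block scales) pair through the layer in place of a plain hidden_states tensor whenever the RMSNorm + FP4 fusion applies. A condensed sketch of the resulting dataflow (illustrative pseudocode, assuming the flag and QuarkW4A4MXFP4 scheme checks above all pass):

# One decoder layer with the fused FP4 path enabled (illustrative only; the actual
# code handles the residual-is-None first-layer case separately).
x_fp4, x_scales, residual = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(
    hidden_states, input_layernorm.weight, input_layernorm.variance_epsilon, residual)
attn_out = self_attn(positions=positions, hidden_states=(x_fp4, x_scales))   # qkv_proj unpacks the tuple
x_fp4, x_scales, residual = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(
    attn_out, post_attention_layernorm.weight, post_attention_layernorm.variance_epsilon, residual)
mlp_out = mlp((x_fp4, x_scales))                                              # gate_up_proj unpacks the tuple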
