Commit c5b0cf9

add aiter fp8 block scaled moe and w8a8 block gemm kernels
Signed-off-by: vllmellm <[email protected]>
1 parent 5c46937 commit c5b0cf9

File tree

4 files changed: +90 -8 lines

  Dockerfile.rocm_base
  vllm/envs.py
  vllm/model_executor/layers/quantization/fp8.py
  vllm/model_executor/layers/quantization/utils/fp8_utils.py
Dockerfile.rocm_base

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="0508c8df"
+ARG AITER_BRANCH="e1ec015"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 
 FROM ${BASE_IMAGE} AS base

vllm/envs.py

Lines changed: 14 additions & 0 deletions

@@ -25,6 +25,8 @@
     VLLM_USE_AITER_PAGED_ATTN: bool = False
     VLLM_USE_AITER_LINEAR: bool = False
     VLLM_USE_AITER_NORM: bool = False
+    VLLM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
+    VLLM_USE_AITER_W8A8_BLOCK_GEMM: bool = False
     RANK: int = 0
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
@@ -322,6 +324,18 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
              ("true", "1") and os.getenv("VLLM_USE_AITER_NORM", "True").lower() in
              ("true", "1")),
 
+    # use aiter fp8 block scaled moe kernel op if aiter ops are enabled.
+    "VLLM_USE_AITER_FP8_BLOCK_SCALED_MOE":
+    lambda: (os.getenv("VLLM_USE_AITER", "False").lower() in
+             ("true", "1") and os.getenv("VLLM_USE_AITER_FP8_BLOCK_SCALED_MOE",
+                                         "False").lower() in ("true", "1")),
+
+    # use aiter w8a8 block gemm kernel op if aiter ops are enabled.
+    "VLLM_USE_AITER_W8A8_BLOCK_GEMM":
+    lambda: (os.getenv("VLLM_USE_AITER", "False").lower() in
+             ("true", "1") and os.getenv("VLLM_USE_AITER_W8A8_BLOCK_GEMM",
+                                         "False").lower() in ("true", "1")),
+
     # rank of the process in the distributed setting, used to determine
     # the driver worker
     "RANK":

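Both new flags follow the existing AITER gating pattern: they default to off and only take effect when the umbrella VLLM_USE_AITER flag is also truthy. A minimal sketch of turning them on, assuming a ROCm build of vLLM with AITER installed (only the flag names come from this diff; the rest is illustrative):

import os

# Each per-kernel flag is ANDed with the umbrella VLLM_USE_AITER flag,
# so both must be set for the AITER path to be taken.
os.environ["VLLM_USE_AITER"] = "1"
os.environ["VLLM_USE_AITER_FP8_BLOCK_SCALED_MOE"] = "1"
os.environ["VLLM_USE_AITER_W8A8_BLOCK_GEMM"] = "1"

import vllm.envs as envs

# The lambdas registered above are evaluated lazily on attribute access.
assert envs.VLLM_USE_AITER_FP8_BLOCK_SCALED_MOE
assert envs.VLLM_USE_AITER_W8A8_BLOCK_GEMM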
vllm/model_executor/layers/quantization/fp8.py

Lines changed: 58 additions & 1 deletion

@@ -35,7 +35,8 @@
 from vllm.utils import is_navi
 
 if envs.VLLM_USE_AITER_MOE:
-    from aiter.fused_moe_bf16_asm import asm_moe
+    import aiter
+    from aiter.fused_moe_bf16_asm import asm_moe, moe_sorting_ck
     from aiter.ops.shuffle import shuffle_weight
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -608,6 +609,14 @@ def process_weights_after_loading(self, layer: Module) -> None:
             layer.w2_weight = Parameter(w2_weight, requires_grad=False)
             layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                   requires_grad=False)
+            if envs.VLLM_USE_AITER_FP8_BLOCK_SCALED_MOE:
+                layer.w13_weight = torch.nn.Parameter(shuffle_weight(
+                    layer.w13_weight.data),
+                    requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(shuffle_weight(
+                    layer.w2_weight.data),
+                    requires_grad=False)
+
             return
 
         # If checkpoint is fp16, quantize in place.
@@ -798,6 +807,8 @@ def apply(
         e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            per_token_group_quant_fp8)
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -812,6 +823,52 @@
             e_score_correction_bias=e_score_correction_bias,
         )
 
+        if envs.VLLM_USE_AITER_FP8_BLOCK_SCALED_MOE:
+            w1 = layer.w13_weight
+            w2 = layer.w2_weight
+            w1_scale = (layer.w13_weight_scale_inv
+                        if self.block_quant else layer.w13_weight_scale)
+            w2_scale = (layer.w2_weight_scale_inv
+                        if self.block_quant else layer.w2_weight_scale)
+
+            block_shape = self.quant_config.weight_block_size
+            # The default block sizes are 128 in AITER.
+            if block_shape is None:
+                block_shape = [128, 128]
+
+            local_E = E = w1.shape[0]
+            topk = topk_ids.shape[1]
+            model_dim = w1.shape[-1]
+            dtype = x.dtype
+            scale_blk_k = block_shape[1]
+
+            (
+                sorted_token_ids,
+                sorted_weight_buf,
+                sorted_expert_ids,
+                num_valid_ids,
+                out_asm,
+            ) = moe_sorting_ck(topk_ids, topk_weights, E, model_dim, dtype)
+            a1, a1_scale = per_token_group_quant_fp8(x, scale_blk_k)
+            aiter.fmoe_fp8_blockscale_g1u1(
+                out_asm,
+                a1,
+                w1,
+                w2,
+                sorted_token_ids,
+                sorted_weight_buf,
+                sorted_expert_ids,
+                num_valid_ids,
+                topk,
+                w1_scale.view(local_E, -1),
+                w2_scale.view(local_E, -1),
+                a1_scale.t().contiguous(),
+                block_shape[0],
+                block_shape[1],
+                None,
+            )
+            return out_asm
+
         if envs.VLLM_USE_AITER_MOE:
             return asm_moe(hidden_states=x,
                            w1=layer.w13_weight,
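The new branch in apply() sorts the routed tokens with moe_sorting_ck, quantizes the activations per 128-wide group with per_token_group_quant_fp8, and passes flattened block scales to aiter.fmoe_fp8_blockscale_g1u1. The torch-only sketch below illustrates the scale reshapes that the call performs; the tensor sizes are made-up examples, not values from this commit:

import torch

# Illustrative sizes only (assumed): 8 experts, fused w13 output dim 14336,
# model dim 4096, 16 tokens, and the default 128x128 scale blocks.
E, N, K, M, blk = 8, 14336, 4096, 16, 128

# Block-quantized expert weights carry one scale per 128x128 block, giving a
# per-expert scale tensor of shape [E, N/blk, K/blk]; the kernel call above
# flattens it per expert via w1_scale.view(local_E, -1).
w1_scale = torch.rand(E, N // blk, K // blk)
print(w1_scale.view(E, -1).shape)        # torch.Size([8, 3584])

# per_token_group_quant_fp8(x, blk) yields one scale per 128-element group of
# each token's activations, i.e. shape [M, K/blk]; the diff hands it to the
# kernel as a1_scale.t().contiguous().
a1_scale = torch.rand(M, K // blk)
print(a1_scale.t().contiguous().shape)   # torch.Size([32, 16])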

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 17 additions & 6 deletions

@@ -10,6 +10,7 @@
 import triton
 import triton.language as tl
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -71,12 +72,22 @@ def apply_w8a8_block_fp8_linear(
     q_input, x_scale = per_token_group_quant_fp8(input_2d,
                                                  block_size[1],
                                                  column_major_scales=False)
-    output = w8a8_block_fp8_matmul(q_input,
-                                   weight,
-                                   x_scale,
-                                   weight_scale,
-                                   block_size,
-                                   output_dtype=input.dtype)
+    if envs.VLLM_USE_AITER_W8A8_BLOCK_GEMM:
+        import aiter
+        output = torch.zeros(
+            [q_input.shape[0], weight.shape[0]],
+            dtype=input.dtype,
+            device=q_input.device,
+        )
+        aiter.gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale,
+                                   output)
+    else:
+        output = w8a8_block_fp8_matmul(q_input,
+                                       weight,
+                                       x_scale,
+                                       weight_scale,
+                                       block_size,
+                                       output_dtype=input.dtype)
     if bias is not None:
         output = output + bias
     return output.to(dtype=input.dtype).view(*output_shape)
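Unlike the Triton w8a8_block_fp8_matmul fallback, which returns its result, aiter.gemm_a8w8_blockscale writes into a caller-allocated buffer, which is why the new branch preallocates output with torch.zeros. A small torch-only sketch of that shape bookkeeping; sizes and stand-in dtypes are illustrative assumptions, not values from this commit:

import torch

# Illustrative sizes (assumed): 32 tokens, hidden size 4096, 11008 output features.
M, K, N = 32, 4096, 11008
q_input = torch.empty(M, K, dtype=torch.int8)   # stand-in for the fp8-quantized activations
weight = torch.empty(N, K, dtype=torch.int8)    # stand-in for the fp8 block-quantized weight

# The AITER branch allocates the destination as [num_tokens, out_features] in
# the unquantized activation dtype before calling the kernel, matching the
# torch.zeros([q_input.shape[0], weight.shape[0]], ...) call in the diff.
output = torch.zeros([q_input.shape[0], weight.shape[0]],
                     dtype=torch.bfloat16, device=q_input.device)
print(output.shape)  # torch.Size([32, 11008])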
