Commit a269173
[Quantization/NVFP4] Speed up TRTLLM NVFP4 MOE weight loading and fix K/V scale loading for MLA Attn (vllm-project#25968)
Signed-off-by: Pavani Majety <[email protected]>
1 parent cd9e5b8 commit a269173

File tree (3 files changed, +77 / -55 lines)

vllm/model_executor/layers/quantization/kv_cache.py
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/model_loader/weight_utils.py


vllm/model_executor/layers/quantization/kv_cache.py

Lines changed: 4 additions & 4 deletions
@@ -86,7 +86,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             logger.warning_once(
                 "Checkpoint does not provide a q scaling factor. "
                 "Setting it to k_scale. This only matters for "
-                "the flash-attn backend.")
+                "FP8 Attention backends (flash-attn or flashinfer).")
             layer._q_scale.copy_(k_scale)
             layer._q_scale_float = k_scale

@@ -98,9 +98,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if (k_scale == 1.0 and v_scale == 1.0
                 and "e5m2" not in layer.kv_cache_dtype):
             logger.warning_once(
-                "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                "may cause accuracy issues. Please make sure k/v_scale "
-                "scaling factors are available in the fp8 checkpoint.")
+                "Using KV cache scaling factor 1.0 for fp8_e4m3. "
+                "If this is unintended, verify that k/v_scale "
+                "scaling factors are properly set in the checkpoint.")

         if layer.q_scale > 0.0:
             q_scale = layer.q_scale
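
For reference, a minimal standalone sketch of the fall-back behaviour these warnings describe (a simplified free function, not the vllm method; layer attributes are reduced to plain arguments): a missing q scaling factor is replaced by k_scale, and default 1.0 k/v scales with an fp8_e4m3 cache trigger the warning because they usually mean the checkpoint carries no calibrated scales.

import warnings

def resolve_attn_scales(q_scale, k_scale, v_scale, kv_cache_dtype):
    """Sketch of the scale fall-backs above (not the vllm implementation)."""
    if q_scale is None or q_scale < 0.0:
        # No q scaling factor in the checkpoint: reuse k_scale. This only
        # matters for FP8 attention backends (flash-attn or flashinfer).
        q_scale = k_scale
    if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in kv_cache_dtype:
        # Default scales with an e4m3 cache usually mean no calibrated
        # k/v scales were stored in the checkpoint.
        warnings.warn("Using KV cache scaling factor 1.0 for fp8_e4m3. "
                      "If this is unintended, verify that k/v_scale "
                      "scaling factors are properly set in the checkpoint.")
    return q_scale, k_scale, v_scale

# Example: checkpoint with no q scale and uncalibrated k/v scales.
print(resolve_attn_scales(None, 1.0, 1.0, "fp8_e4m3"))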

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 66 additions & 50 deletions
@@ -1064,7 +1064,7 @@ def __init__(
         self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.flashinfer_moe_backend = None
-
+        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
         if self.allow_flashinfer:
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
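
The new _cache_permute_indices dict lives on the quantization method object, so permute indices computed for one expert layer are reused by every other expert and layer whose weight tensor has the same shape. A rough sketch of the caching pattern, using a made-up build_indices callable in place of the flashinfer helpers:

import torch

def get_or_build(cache: dict[torch.Size, torch.Tensor],
                 weight: torch.Tensor,
                 build_indices) -> torch.Tensor:
    # The permute indices depend only on the weight's shape, so the
    # shape is the cache key and the indices are computed at most once.
    if weight.shape not in cache:
        cache[weight.shape] = build_indices(weight.shape)
    return cache[weight.shape]

cache: dict[torch.Size, torch.Tensor] = {}
build = lambda shape: torch.randperm(shape[0])  # placeholder index builder
w_layer0 = torch.zeros(256, 128, dtype=torch.uint8)
w_layer1 = torch.zeros(256, 128, dtype=torch.uint8)
idx0 = get_or_build(cache, w_layer0, build)  # computed once
idx1 = get_or_build(cache, w_layer1, build)  # cache hit: same shape
assert idx0 is idx1 and len(cache) == 1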
@@ -1197,19 +1197,23 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                                  weight_loader=weight_loader)
         layer.register_parameter("w2_input_scale", w2_input_scale)

-    def prepare_static_weight_layouts_for_trtllm_moe(
+    def prepare_static_weights_for_trtllm_fp4_moe(
         self,
-        gemm1_weights: torch.Tensor,
-        gemm2_weights: torch.Tensor,
-        gemm1_scales_linear_fp4_bytes: torch.Tensor,
-        gemm2_scales_linear_fp4_bytes: torch.Tensor,
-        hidden_size: int,
-        intermediate_size: int,
-        num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # args_dequant,
+        # args,
+        gemm1_weights,
+        gemm2_weights,
+        gemm1_scales_linear_fp4_bytes,
+        gemm2_scales_linear_fp4_bytes,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+    ):
+        from flashinfer import nvfp4_block_scale_interleave
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices)
         """Prepare quantized weights for kernel (done offline with weights)."""
-        from flashinfer import (reorder_rows_for_gated_act_gemm,
-                                shuffle_matrix_a, shuffle_matrix_sf_a)
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

         # Convert quantized weights to proper formats
@@ -1227,48 +1231,54 @@ def prepare_static_weight_layouts_for_trtllm_moe
                                          intermediate_size //
                                          16)  # fp8 scaling factors

-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(
-                    gemm1_scales_linear_fp4[i].clone()))
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved).reshape(num_experts,
-                                                   2 * intermediate_size,
-                                                   hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved).reshape(num_experts,
-                                                  2 * intermediate_size,
-                                                  hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
-            gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            #    for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm1_weights_fp4_shuffled.append(gemm1_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm1_weights_fp4.device)].contiguous())
+
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
             gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
-
-            gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
-                                 epilogue_tile_m))
+                nvfp4_block_scale_interleave(gemm1_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm1_scales_linear_fp4.device)].contiguous()))
+
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm2_weights_fp4_shuffled.append(gemm2_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm2_weights_fp4.device)].contiguous())
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
             gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8),
-                    epilogue_tile_m))
+                nvfp4_block_scale_interleave(gemm2_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm2_scales_linear_fp4.device)].contiguous()))

         # Stack weights for all experts
         gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
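
A note on the change above: the per-expert calls to reorder_rows_for_gated_act_gemm / shuffle_matrix_a are replaced by a cached index lookup followed by plain advanced indexing, so the expensive index computation runs once per weight shape instead of once per expert on every load. A toy sketch of that loop structure (toy_permutation stands in for the flashinfer index helpers, and the scale-factor interleaving step is omitted):

import torch

def toy_permutation(shape: torch.Size) -> torch.Tensor:
    # Stand-in for the expensive permute-index computation.
    return torch.randperm(shape[0])

def shuffle_expert_weights(weights: torch.Tensor,
                           cache: dict[torch.Size, torch.Tensor]) -> torch.Tensor:
    """Row-shuffle each expert's weight using cached permute indices."""
    shuffled = []
    for i in range(weights.size(0)):  # loop over experts
        if weights[i].shape not in cache:
            cache[weights[i].shape] = toy_permutation(weights[i].shape)
        permute_indices = cache[weights[i].shape]
        # Apply the permutation by indexing; nothing is recomputed per expert.
        shuffled.append(
            weights[i][permute_indices.to(weights.device)].contiguous())
    return torch.stack(shuffled)

cache: dict[torch.Size, torch.Tensor] = {}
w13 = torch.randint(0, 256, (8, 64, 32), dtype=torch.uint8)  # experts x rows x cols
w13_shuffled = shuffle_expert_weights(w13, cache)  # indices built once, reused 7 times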
@@ -1283,8 +1293,12 @@ def prepare_static_weight_layouts_for_trtllm_moe
             torch.stack(gemm2_scales_fp4_shuffled).view(
                 torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                              intermediate_size // 16))
-        return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
-                gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
+        )

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # GEMM 1 processing
@@ -1334,9 +1348,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.allow_flashinfer and \
             self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
             # Prepare static weights for TRT-LLM kernel
+            # alternate: prepare_static_weight_layouts_for_trtllm_moe
             (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
              gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
-             ) = self.prepare_static_weight_layouts_for_trtllm_moe(
+             ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                 layer.w13_weight,
                 layer.w2_weight,
                 layer.w13_weight_scale,
@@ -1345,6 +1360,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 layer.w13_weight.size(-2) // 2,  # intermediate_size
                 layer.w13_weight.size(0),  # num_experts
             )
+            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

             layer.gemm1_weights_fp4_shuffled = Parameter(
                 gemm1_weights_fp4_shuffled, requires_grad=False)

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 7 additions & 1 deletion
@@ -1003,12 +1003,18 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
             return None
         return remapped_name

+    if any("mla_attn" in key for key in params_dict):
+        attn_str = "mla_attn.mla_attn"
+        logger.debug_once(f"Found mla_attn with k_scale and v_scale in "
+                          f"the checkpoint, using {attn_str} as attn_str")
+    else:
+        attn_str = "attn"
     # Define scale name mapping patterns in order of precedence
     scale_mapping_patterns = [
         # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
         #   .self_attn.attn.{k,v}_scale
         (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
-         r".self_attn.attn.\2_scale"),
+         rf".self_attn.{attn_str}.\2_scale"),
         # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
         #   .self_attn.attn.{k,v}_scale
         (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
