Skip to content

Commit 15618ee

Browse files
committed
address review comments
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
1 parent 2a8b6c2 commit 15618ee

File tree

3 files changed

+41
-47
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def trtllm_quant_nvfp4_moe_fused(
306306
)
307307
hidden_size_needs_padding = hidden_size % TRTLLM_NVFP4_COLUMN_SIZE != 0
308308
if inter_size_needs_padding or hidden_size_needs_padding:
309+
assert False, "See https://github.com/NVIDIA/TensorRT-LLM/issues/10331"
309310
# fc1_expert_weights_fp4: [E, I, H] or [E, 2*I, H]
310311
fc1_padded = fc1_expert_weights_fp4.new_zeros(
311312
fc1_expert_weights_fp4.size(0),
@@ -319,6 +320,11 @@ def trtllm_quant_nvfp4_moe_fused(
319320
fc2_padded = fc2_expert_weights_fp4.new_zeros(
320321
n_experts, hidden_size_padded, inter_size_padded // FP4_PER_UINT8
321322
)
323+
324+
assert inter_size % NVFP4_BLOCK_SIZE == 0, (
325+
f"inter_size {inter_size} must be divisible by {NVFP4_BLOCK_SIZE}"
326+
)
327+
322328
fc2_padded[:, :, : inter_size // FP4_PER_UINT8] = fc2_expert_weights_fp4
323329
fc2_expert_weights_fp4 = fc2_padded
324330

@@ -334,17 +340,20 @@ def trtllm_quant_nvfp4_moe_fused(
334340
# https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
335341
quant_scales = [
336342
fc1_act_global_scale, # torch.float32; [E] or scalar
337-
fc1_weight_blockscale_fp8.view(torch.int32),
343+
fc1_weight_blockscale_fp8.view(
344+
torch.int32
345+
), # 4 FP8 as packed int32; [E, I*2, H / 16 / 4] or [E, I, H / 16 / 4]
338346
fc1_alpha, # torch.float32; [E]
339347
fc2_act_global_scale, # torch.float32; [E] or scalar
340-
fc2_weight_blockscale_fp8.view(torch.int32),
348+
fc2_weight_blockscale_fp8.view(torch.int32), # 4 FP8 as packed int32; [E, H, I / 16 / 4]
341349
fc2_alpha, # torch.float32; [E]
342350
]
343351

344352
trtllm_output = torch.ops.trtllm.fused_moe(
345-
x_q_fp4,
346-
selected_experts.to(torch.int),
353+
x_q_fp4.view(torch.long),
354+
selected_experts.to(torch.int32),
347355
routing_weights.to(torch.float32),
356+
# Groups of 16 FP4 weight elements are packed as a single int64 element (see isNvfp4Quant in moeOp.cpp)
348357
fc1_expert_weights=fc1_expert_weights_fp4.view(torch.long),
349358
fc1_expert_biases=None,
350359
fc2_expert_weights=fc2_expert_weights_fp4.view(torch.long),

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py

Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1613,10 +1613,13 @@ def _extract_op_args(node):
16131613
"is_gated_mlp",
16141614
)
16151615

1616-
def _stack(param_list, dim=0):
1617-
return torch.stack(
1618-
[get_param_or_buffer(element.target) for element in param_list], dim=dim
1619-
).contiguous()
1616+
def _stack(param_list, dim=0, device=None, dtype=None):
1617+
if param_list:
1618+
return torch.stack(
1619+
[get_param_or_buffer(element.target) for element in param_list], dim=dim
1620+
).contiguous()
1621+
else:
1622+
return torch.empty(0, device=device, dtype=dtype)
16201623

16211624
def _prepare_args_cutlass_format_nvfp4():
16221625
if is_gated_mlp:
@@ -1627,9 +1630,15 @@ def _prepare_args_cutlass_format_nvfp4():
16271630
fc1_act_scale = torch.cat(
16281631
[w3_input_scale_stacked, w1_input_scale_stacked], dim=1
16291632
).contiguous()
1633+
fc1_alpha_stacked = torch.cat([w3_alpha_stacked, w1_alpha_stacked], dim=1).contiguous()
1634+
fc1_weight_blockscale_fp8_stacked = torch.cat(
1635+
[w3_weight_blockscale_fp8_stacked, w1_weight_blockscale_fp8_stacked], dim=1
1636+
).contiguous()
16301637
else:
16311638
fc1_expert_weights = w1_stacked
16321639
fc1_act_scale = w1_input_scale_stacked
1640+
fc1_alpha_stacked = w1_alpha_stacked
1641+
fc1_weight_blockscale_fp8_stacked = w1_weight_blockscale_fp8_stacked
16331642

16341643
fc2_expert_weights = w2_stacked
16351644
fc2_act_scale = w2_input_scale_stacked
@@ -1651,11 +1660,13 @@ def _prepare_args_cutlass_format_nvfp4():
16511660
weight_dtype = torch.float8_e4m3fn
16521661
_register_parameter(gm, new_key_fc1_expert_weights, fc1_expert_weights.to(weight_dtype))
16531662
_register_parameter(gm, new_key_fc2_expert_weights, fc2_expert_weights.to(weight_dtype))
1654-
_register_parameter(gm, new_key_fc1_weight_blockscale_fp8, w1_weight_blockscale_fp8_stacked)
1663+
_register_parameter(
1664+
gm, new_key_fc1_weight_blockscale_fp8, fc1_weight_blockscale_fp8_stacked
1665+
)
16551666
_register_parameter(gm, new_key_fc2_weight_blockscale_fp8, w2_weight_blockscale_fp8_stacked)
16561667
_register_parameter(gm, new_key_fc1_act_scale, fc1_act_scale)
16571668
_register_parameter(gm, new_key_fc2_act_scale, fc2_act_scale)
1658-
_register_parameter(gm, new_key_fc1_alpha, w1_alpha_stacked)
1669+
_register_parameter(gm, new_key_fc1_alpha, fc1_alpha_stacked)
16591670
_register_parameter(gm, new_key_fc2_alpha, w2_alpha_stacked)
16601671

16611672
with graph.inserting_before(node):
@@ -1705,50 +1716,23 @@ def _prepare_args_cutlass_format_nvfp4():
17051716
# Stack the actual tensor values (fast, like in quantize_moe.py)
17061717
w1_stacked = _stack(w1_list, dim=0)
17071718
w2_stacked = _stack(w2_list, dim=0)
1708-
w3_stacked = (
1709-
_stack(w3_list, dim=0)
1710-
if w3_list
1711-
else torch.empty(0, device=w1_stacked.device, dtype=w1_stacked.dtype)
1712-
)
1719+
device, dtype = (w1_stacked.device, w1_stacked.dtype)
1720+
w3_stacked = _stack(w3_list, dim=0, device=device, dtype=dtype)
17131721

17141722
# Scales are buffers, not parameters
17151723
w1_input_scale_stacked = _stack(w1_input_scale, dim=0)
17161724
w2_input_scale_stacked = _stack(w2_input_scale, dim=0)
1717-
w3_input_scale_stacked = (
1718-
_stack(w3_input_scale, dim=0)
1719-
if w3_input_scale
1720-
else torch.empty(
1721-
0, device=w1_input_scale_stacked.device, dtype=w1_input_scale_stacked.dtype
1722-
)
1723-
)
1724-
# assert torch.all(w1_input_scale_stacked[0] == w1_input_scale_stacked), (
1725-
# "All w1 scales should have the same value."
1726-
# )
1727-
# assert torch.all(w2_input_scale_stacked[0] == w2_input_scale_stacked), (
1728-
# "All w2 scales should have the same value."
1729-
# )
1725+
w3_input_scale_stacked = _stack(w3_input_scale, dim=0, device=device, dtype=dtype)
17301726

17311727
w1_weight_blockscale_fp8_stacked = _stack(w1_weight_scale, dim=0).to(torch.float8_e4m3fn)
17321728
w2_weight_blockscale_fp8_stacked = _stack(w2_weight_scale, dim=0).to(torch.float8_e4m3fn)
1733-
# w3_weight_blockscale_fp8_stacked = (
1734-
# (
1735-
# _stack(w3_weight_scale, dim=0)
1736-
# if w3_weight_scale
1737-
# else torch.empty(
1738-
# 0,
1739-
# device=w1_weight_blockscale_fp8_stacked.device,
1740-
# dtype=w1_weight_blockscale_fp8_stacked.dtype,
1741-
# )
1742-
# )
1743-
# .to(torch.float8_e4m3fn)
1744-
# .contiguous()
1745-
# )
1746-
1747-
###
1729+
w3_weight_blockscale_fp8_stacked = _stack(
1730+
w3_weight_scale, dim=0, device=device, dtype=dtype
1731+
).to(torch.float8_e4m3fn)
1732+
17481733
w1_alpha_stacked = _stack(w1_alpha, dim=0)
17491734
w2_alpha_stacked = _stack(w2_alpha, dim=0)
1750-
# w3_alpha_stacked = _stack(w3_alpha, dim=0)
1751-
###
1735+
w3_alpha_stacked = _stack(w3_alpha, dim=0, device=device, dtype=dtype)
17521736

17531737
args = _prepare_args_cutlass_format_nvfp4()
17541738

@@ -1770,7 +1754,6 @@ def _prepare_args_cutlass_format_nvfp4():
17701754
# will remove the parameters/buffers that are no longer referenced
17711755
gm.graph.eliminate_dead_code()
17721756
gm.delete_all_unused_submodules()
1773-
17741757
return fused_key_counter
17751758

17761759

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,9 @@ def test_trtllm_fused_moe_nvfp4(
584584
):
585585
# Skip known failing configuration
586586
if activation_func == ActivationType.Relu2 and intermediate_size == 1856:
587-
pytest.skip("test fails for Relu2 with intermediate_size=1856")
587+
pytest.skip(
588+
"test fails for Relu2 with intermediate_size=1856; see https://github.com/NVIDIA/TensorRT-LLM/issues/10331"
589+
)
588590

589591
# In the code below:
590592
# sf := block scale factors for NVFP4

0 commit comments

Comments (0)