From e43d3662646c0af8e36a3c2eecf7f582364e8f87 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Thu, 1 Jan 2026 05:44:05 -0800 Subject: [PATCH 1/2] Added proper rescaling of FP4 weights Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../_torch/auto_deploy/utils/node_utils.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 286650c77fd..fdbe028f9a2 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -304,6 +304,16 @@ def is_any_lin_op(node: Node) -> bool: return is_linear_op(node) or is_fake_quantized_linear_op(node) +def is_fp4_op(node: Node) -> bool: + return is_op( + node, + [ + torch.ops.auto_deploy.torch_quant_nvfp4_linear, + torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear, + ], + ) + + def is_any_moe_op(node: Node) -> bool: return is_op( node, @@ -739,10 +749,14 @@ def get_weight_shape( """Get the shape of the weight node.""" if not is_any_lin_op(node): return None + s = list(shape(extract_weight_node(node))) + if is_fp4_op(node): + # FP4 weights are stored as half-sized FP8 tensor + s[-1] *= 2 if dim is None: - return shape(extract_weight_node(node)) + return s else: - return shape(extract_weight_node(node))[dim] + return s[dim] def get_layer_after_linear_node( From 0d173885ef41359e1154177ec970a324c40e17e2 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Fri, 2 Jan 2026 11:05:05 -0800 Subject: [PATCH 2/2] improved comment Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index fdbe028f9a2..ded746fe786 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -743,15 +743,15 @@ def boundary_condition(n): return subgraph_nodes -def get_weight_shape( - node: Node, dim: Optional[int] = None -) -> Optional[Union[int, Tuple[int, ...]]]: +def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]: """Get the shape of the weight node.""" if not is_any_lin_op(node): return None s = list(shape(extract_weight_node(node))) + if len(s) == 0: + return None if is_fp4_op(node): - # FP4 weights are stored as half-sized FP8 tensor + # FP4 weights are packed as uint8 type with 2 FP4 values per element s[-1] *= 2 if dim is None: return s