diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
index 286650c77fd..ded746fe786 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
@@ -304,6 +304,16 @@ def is_any_lin_op(node: Node) -> bool:
     return is_linear_op(node) or is_fake_quantized_linear_op(node)
 
 
+def is_fp4_op(node: Node) -> bool:
+    return is_op(
+        node,
+        [
+            torch.ops.auto_deploy.torch_quant_nvfp4_linear,
+            torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear,
+        ],
+    )
+
+
 def is_any_moe_op(node: Node) -> bool:
     return is_op(
         node,
@@ -733,16 +743,20 @@ def boundary_condition(n):
     return subgraph_nodes
 
 
-def get_weight_shape(
-    node: Node, dim: Optional[int] = None
-) -> Optional[Union[int, Tuple[int, ...]]]:
+def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]:
     """Get the shape of the weight node."""
     if not is_any_lin_op(node):
         return None
+    s = list(shape(extract_weight_node(node)))
+    if len(s) == 0:
+        return None
+    if is_fp4_op(node):
+        # FP4 weights are packed as uint8 type with 2 FP4 values per element
+        s[-1] *= 2
     if dim is None:
-        return shape(extract_weight_node(node))
+        return s
     else:
-        return shape(extract_weight_node(node))[dim]
+        return s[dim]
 
 
 def get_layer_after_linear_node(
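
The packing-aware shape adjustment this diff adds can be checked in isolation. Below is a minimal, self-contained sketch of the same last-dimension doubling; `logical_weight_shape` and the `packed` flag are illustrative stand-ins (not part of the `node_utils` API), since the real `get_weight_shape` operates on fx graph nodes via `extract_weight_node` and `is_fp4_op` rather than on tensors directly:

```python
from typing import List, Optional, Union

import torch


def logical_weight_shape(
    weight: torch.Tensor, packed: bool, dim: Optional[int] = None
) -> Optional[Union[int, List[int]]]:
    """Return the logical shape of a weight, unpacking the last dim if FP4-packed."""
    s = list(weight.shape)
    if len(s) == 0:
        # Scalar weights carry no meaningful shape, mirroring the diff's early return.
        return None
    if packed:
        # NVFP4 stores two 4-bit values per uint8 element, so the stored
        # last dimension is half the logical one.
        s[-1] *= 2
    return s if dim is None else s[dim]


# A logical [128, 256] FP4 weight is stored as a [128, 128] uint8 tensor.
packed_w = torch.zeros(128, 128, dtype=torch.uint8)
assert logical_weight_shape(packed_w, packed=True) == [128, 256]
assert logical_weight_shape(packed_w, packed=True, dim=-1) == 256
```

This also illustrates why the return annotation changes from `Tuple[int, ...]` to `List[int]`: the shape is now copied into a mutable list so the packed last dimension can be rescaled before being returned.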