Skip to content

Commit e43d366

Browse files
Added proper rescaling of FP4 weights
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 74832a1 commit e43d366

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,16 @@ def is_any_lin_op(node: Node) -> bool:
304304
return is_linear_op(node) or is_fake_quantized_linear_op(node)
305305

306306

307+
def is_fp4_op(node: Node) -> bool:
308+
return is_op(
309+
node,
310+
[
311+
torch.ops.auto_deploy.torch_quant_nvfp4_linear,
312+
torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear,
313+
],
314+
)
315+
316+
307317
def is_any_moe_op(node: Node) -> bool:
308318
return is_op(
309319
node,
@@ -739,10 +749,14 @@ def get_weight_shape(
739749
"""Get the shape of the weight node."""
740750
if not is_any_lin_op(node):
741751
return None
752+
s = list(shape(extract_weight_node(node)))
753+
if is_fp4_op(node):
754+
# FP4 weights are stored as half-sized FP8 tensor
755+
s[-1] *= 2
742756
if dim is None:
743-
return shape(extract_weight_node(node))
757+
return s
744758
else:
745-
return shape(extract_weight_node(node))[dim]
759+
return s[dim]
746760

747761

748762
def get_layer_after_linear_node(

0 commit comments

Comments
 (0)