File tree Expand file tree Collapse file tree 1 file changed +16
-2
lines changed
tensorrt_llm/_torch/auto_deploy/utils Expand file tree Collapse file tree 1 file changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -304,6 +304,16 @@ def is_any_lin_op(node: Node) -> bool:
304304 return is_linear_op (node ) or is_fake_quantized_linear_op (node )
305305
306306
def is_fp4_op(node: Node) -> bool:
    """Return whether ``node`` matches one of the NVFP4 linear ops.

    The check is delegated to ``is_op`` against both the quantized and the
    fake-quantized NVFP4 linear custom ops registered under
    ``torch.ops.auto_deploy``.

    Args:
        node: The graph node to test.

    Returns:
        The result of ``is_op`` for ``node`` and the two NVFP4 linear ops.
    """
    nvfp4_linear_ops = [
        torch.ops.auto_deploy.torch_quant_nvfp4_linear,
        torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear,
    ]
    return is_op(node, nvfp4_linear_ops)
315+
316+
307317def is_any_moe_op (node : Node ) -> bool :
308318 return is_op (
309319 node ,
@@ -739,10 +749,14 @@ def get_weight_shape(
739749 """Get the shape of the weight node."""
740750 if not is_any_lin_op (node ):
741751 return None
752+ s = list (shape (extract_weight_node (node )))
753+ if is_fp4_op (node ):
754+ # FP4 weights are stored as half-sized FP8 tensor
755+ s [- 1 ] *= 2
742756 if dim is None :
743- return shape ( extract_weight_node ( node ))
757+ return s
744758 else :
745- return shape ( extract_weight_node ( node )) [dim ]
759+ return s [dim ]
746760
747761
748762def get_layer_after_linear_node (
You can’t perform that action at this time.
0 commit comments