Skip to content

Commit 0d17388

Browse files
improved comment
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent e43d366 commit 0d17388

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -743,15 +743,15 @@ def boundary_condition(n):
743743
return subgraph_nodes
744744

745745

746-
def get_weight_shape(
747-
node: Node, dim: Optional[int] = None
748-
) -> Optional[Union[int, Tuple[int, ...]]]:
746+
def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]:
749747
"""Get the shape of the weight node."""
750748
if not is_any_lin_op(node):
751749
return None
752750
s = list(shape(extract_weight_node(node)))
751+
if len(s) == 0:
752+
return None
753753
if is_fp4_op(node):
754-
# FP4 weights are stored as half-sized FP8 tensor
754+
# FP4 weights are packed as uint8 type with 2 FP4 values per element
755755
s[-1] *= 2
756756
if dim is None:
757757
return s

0 commit comments

Comments
 (0)