Skip to content

Commit f885690

Browse files
code cleanup
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 6de8a85 commit f885690

File tree

2 files changed

+5
-15
lines changed

2 files changed

+5
-15
lines changed

tensorrt_llm/_torch/auto_deploy/utils/_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def get_input_embeddings(model: nn.Module) -> torch.Tensor:
354354
op="call_function", target=torch.ops.aten.embedding.default
355355
)
356356
for node in found_nodes:
357-
embedding_weights.append(get_weight_tensor(gm, node))
357+
embedding_weights.append(get_weight_tensor(node))
358358

359359
if hasattr(model, "get_input_embeddings"):
360360
embedding_weights.append(model.get_input_embeddings())
@@ -400,4 +400,4 @@ def get_lm_head_node(gm: GraphModule, output_node: Optional[Node] = None) -> Nod
400400
def get_lm_head_weights(model: nn.Module) -> torch.Tensor:
401401
gm, output_node = get_output_node(model)
402402
lm_head_node = get_lm_head_node(gm, output_node)
403-
return get_weight_tensor(gm, lm_head_node)
403+
return get_weight_tensor(lm_head_node)

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -273,16 +273,6 @@ def num_users_of_weight_node(node: Node) -> int:
273273
return len(weight_node.users) if weight_node is not None else 0
274274

275275

276-
def extract_param_names_from_node(node: Node) -> Tuple[List[str], Optional[List[str]]]:
277-
"""Extracts the name of the parameter associated with the given parametrized node.
278-
279-
Args:
280-
node: node with weight parameters in the graph.
281-
"""
282-
weight_nodes, bias_nodes = extract_weight_nodes(node)
283-
return [n.node_key for n in weight_nodes], [n.node_key for n in bias_nodes]
284-
285-
286276
def get_op_overload_packet(node: Union[OpOverloadPacket, OpOverload]) -> OpOverloadPacket:
287277
"""Get the overload packet from the op overload."""
288278
if isinstance(node, OpOverloadPacket):
@@ -1011,10 +1001,10 @@ def shape(node: Node) -> Tuple[int, ...]:
10111001
return node.meta["val"].shape
10121002

10131003

1014-
def get_weight_tensor(gm: GraphModule, node: Node) -> "torch.Tensor":
1004+
def get_weight_tensor(node: Node) -> torch.Tensor:
10151005
"""Extract the weight tensor from a node within a GraphModule."""
1016-
weight_name = extract_param_names_from_node(node)[0]
1017-
return gm.get_parameter(weight_name)
1006+
weight_nodes = extract_weight_nodes(node)
1007+
return weight_nodes.weights[0].tensor
10181008

10191009

10201010
def draw_graph(gm: GraphModule, filename: str):

0 commit comments

Comments
 (0)