Skip to content

Commit 97b42e7

Browse files
fixed BMM weight finding
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 3e44abd commit 97b42e7

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,16 +226,17 @@ def find_get_attr_node(weight_node: Node) -> Node:
226226

227227
if is_op(node, torch.ops.aten.bmm):
228228
# no bias for bmm
229+
weight_node = find_get_attr_node(node.args[1])
229230
return WeightNodes(
230-
[
231+
weights=[
231232
WeightNode(
232233
node=node.args[1],
233-
node_key=node.args[1].target,
234-
tensor=get_const_tensor(node.args[1].target, gm),
235-
submod=gm.get_submodule(node.args[1].target.rpartition(".")[0]),
234+
node_key=weight_node.target,
235+
tensor=get_const_tensor(weight_node.target, gm),
236+
submod=gm.get_submodule(weight_node.target.rpartition(".")[0]),
236237
)
237238
],
238-
[],
239+
biases=[],
239240
)
240241
# for other parametrized nodes, we need to find the weight node
241242
else:

0 commit comments

Comments
 (0)