Skip to content

Commit 3e44abd

Browse files
code cleanup
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent d597dc2 commit 3e44abd

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,7 @@ def _insert_quantized_linear(
140140
The state_dict is also updated to contain the sharded weights.
141141
"""
142142
weight_nodes = extract_weight_nodes(node)
143+
assert len(weight_nodes.weights) == 1, "Expected exactly one weight node"
143144
lin_weight = weight_nodes.weights[0]
144145
new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False)
145146
modname, _, attrname = lin_weight.node_key.rpartition(".")

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -231,7 +231,8 @@ def find_get_attr_node(weight_node: Node) -> Node:
231231
WeightNode(
232232
node=node.args[1],
233233
node_key=node.args[1].target,
234-
tensor=gm.get_parameter(node.args[1].target),
234+
tensor=get_const_tensor(node.args[1].target, gm),
235+
submod=gm.get_submodule(node.args[1].target.rpartition(".")[0]),
235236
)
236237
],
237238
[],

0 commit comments

Comments (0)