Skip to content

Commit 1f66353

Browse files
code cleanup
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 9bc20e0 commit 1f66353

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def _insert_quantized_linear(
         The state_dict is also updated to contain the sharded weights.
         """
         weight_nodes = extract_weight_nodes(node)
+        assert len(weight_nodes.weights) == 1, "Expected exactly one weight node"
         lin_weight = weight_nodes.weights[0]
         new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False)
         modname, _, attrname = lin_weight.node_key.rpartition(".")

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ def find_get_attr_node(weight_node: Node) -> Node:
             WeightNode(
                 node=node.args[1],
                 node_key=node.args[1].target,
-                tensor=gm.get_parameter(node.args[1].target),
+                tensor=get_const_tensor(node.args[1].target, gm),
+                submod=gm.get_submodule(node.args[1].target.rpartition(".")[0]),
             )
         ],
         [],

0 commit comments

Comments
 (0)