Skip to content

Commit 9bc20e0

Browse files
code cleanup
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent f885690 commit 9bc20e0

File tree

3 files changed

+1
-13
lines changed

3 files changed

+1
-13
lines changed

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,17 +1325,6 @@ def _shard_parameter_node(
13251325

13261326
rank, world_size = config.rank, config.world_size
13271327
allreduce_strategy = config.allreduce_strategy.name
1328-
# num_users = num_users_of_weight_node(node)
1329-
# if num_users > 1 or num_users == 0:
1330-
# ad_logger.warning(
1331-
# f"Weight node {node} has {num_users} users. This is not supported for sharding. Skipping."
1332-
# )
1333-
# return
1334-
# # get weight and bias key
1335-
# weight_key, bias_key = extract_param_names_from_node(node)
1336-
1337-
# modname = weight_key.rpartition(".")[0]
1338-
# submod = gm.get_submodule(modname)
13391328

13401329
# # Shard weight using the unified function (also updates the parameter)
13411330
# original_weight = gm.get_parameter(weight_key)

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def find_get_attr_node(weight_node: Node) -> Node:
269269

270270
def num_users_of_weight_node(node: Node) -> int:
271271
"""Returns the number of users of the weight node of the given parametrized node."""
272-
weight_node = extract_weight_nodes(node)[0]
272+
weight_node = extract_weight_nodes(node).weights[0].node
273273
return len(weight_node.users) if weight_node is not None else 0
274274

275275

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ def should_skip_quantization(
117117
else:
118118
if not (is_linear_op(node_or_name) or is_bmm_op(node_or_name)):
119119
return True
120-
# param_names, _ = extract_param_names_from_node(node_or_name)
121120
weight_name = extract_weight_name(node_or_name)
122121
modname = weight_name.rpartition(".")[0]
123122

0 commit comments

Comments (0)