tensorrt_llm/_torch/auto_deploy: 3 files changed, +1 -13 lines

@@ -1325,17 +1325,6 @@ def _shard_parameter_node(
 
     rank, world_size = config.rank, config.world_size
     allreduce_strategy = config.allreduce_strategy.name
-    # num_users = num_users_of_weight_node(node)
-    # if num_users > 1 or num_users == 0:
-    #     ad_logger.warning(
-    #         f"Weight node {node} has {num_users} users. This is not supported for sharding. Skipping."
-    #     )
-    #     return
-    # # get weight and bias key
-    # weight_key, bias_key = extract_param_names_from_node(node)
-
-    # modname = weight_key.rpartition(".")[0]
-    # submod = gm.get_submodule(modname)
 
     # # Shard weight using the unified function (also updates the parameter)
     # original_weight = gm.get_parameter(weight_key)
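The deleted comments referenced `num_users_of_weight_node`, i.e. how many graph nodes consume the weight's `get_attr` node, and skipped sharding when that count was not exactly one. A self-contained `torch.fx` toy below illustrates that quantity; the `Tied` module and its tied-weight setup are made up for the demo, not taken from the repo.

```python
# Toy illustration (assumes only `torch` is installed): a single parameter is
# consumed by two linear calls, so its get_attr node ends up with two users --
# the count that num_users_of_weight_node() is meant to report.
import torch
import torch.fx as fx


class Tied(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # Same parameter used twice -> one get_attr node with two users.
        y1 = torch.nn.functional.linear(x, self.weight)
        y2 = torch.nn.functional.linear(x, self.weight)
        return y1 + y2


gm = fx.symbolic_trace(Tied())
weight_node = next(n for n in gm.graph.nodes if n.op == "get_attr" and n.target == "weight")
print(len(weight_node.users))  # 2 -- a weight like this would be skipped for sharding
```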
@@ -269,7 +269,7 @@ def find_get_attr_node(weight_node: Node) -> Node:
 
 def num_users_of_weight_node(node: Node) -> int:
     """Returns the number of users of the weight node of the given parametrized node."""
-    weight_node = extract_weight_nodes(node)[0]
+    weight_node = extract_weight_nodes(node).weights[0].node
     return len(weight_node.users) if weight_node is not None else 0
 
 
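The one-line fix above presumes that `extract_weight_nodes` now returns a structured result rather than a plain list of nodes. The dataclasses below are assumptions sketched only to show a shape under which `.weights[0].node` is the right access pattern; the real definition lives in the auto_deploy utilities.

```python
# Minimal sketch, assuming a structured return type for extract_weight_nodes().
# WeightRef/WeightNodes are hypothetical names; only torch.fx.Node is real.
from dataclasses import dataclass, field
from typing import List, Optional

from torch.fx import Node


@dataclass
class WeightRef:
    node: Optional[Node]  # get_attr node for one weight (or bias) parameter


@dataclass
class WeightNodes:
    weights: List[WeightRef] = field(default_factory=list)
    biases: List[WeightRef] = field(default_factory=list)


def first_weight_node(extracted: WeightNodes) -> Optional[Node]:
    """Mirrors the fixed access: extracted.weights[0].node, guarded for empty lists."""
    return extracted.weights[0].node if extracted.weights else None
```

Under a shape like this, the old `extract_weight_nodes(node)[0]` would fail because the container is not subscriptable, which is consistent with only the indexing line changing in this hunk.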
@@ -117,7 +117,6 @@ def should_skip_quantization(
     else:
         if not (is_linear_op(node_or_name) or is_bmm_op(node_or_name)):
             return True
-        # param_names, _ = extract_param_names_from_node(node_or_name)
         weight_name = extract_weight_name(node_or_name)
         modname = weight_name.rpartition(".")[0]
 
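For reference, the surviving `rpartition` line derives the owning submodule's qualified name by stripping the final attribute segment from the parameter name. A quick worked example (the parameter name is made up):

```python
# rpartition(".") splits on the last dot, so [0] is everything before it --
# i.e. the qualified name of the submodule that owns the parameter.
weight_name = "model.layers.0.mlp.gate_proj.weight"  # hypothetical example name
modname = weight_name.rpartition(".")[0]
print(modname)  # model.layers.0.mlp.gate_proj
```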