
Commit d90a8e5

[TRTLLM-10673][feat] Improved layer classification for sharding (NVIDIA#10718)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Parent: 925d911

3 files changed: +375, -127 lines

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 5 additions & 4 deletions
@@ -17,8 +17,8 @@
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils.node_utils import (
-    extract_weight_nodes,
     get_quantization_params_from_linear_node,
+    get_weight_info,
     is_bmm_op,
     is_linear_op,
 )
@@ -141,9 +141,10 @@ def _insert_quantized_linear(

        The state_dict is also updated to contain the sharded weights.
        """
-        weight_nodes = extract_weight_nodes(node)
-        assert len(weight_nodes.weights) == 1, "Expected exactly one weight node"
-        lin_weight = weight_nodes.weights[0]
+        lin_weight = get_weight_info(node)
+        if lin_weight is None:
+            raise ValueError(f"Linear node {node.name} has no weight")
+
        new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False)
        modname, _, attrname = lin_weight.node_key.rpartition(".")
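Note on this change: extract_weight_nodes returned a list and the code asserted on its length, whereas get_weight_info returns a single weight descriptor or None, which the transform now turns into an explicit ValueError. For readers unfamiliar with the helper, below is a minimal, hypothetical sketch of the kind of lookup involved over a torch.fx graph. The names WeightInfo and get_weight_info_sketch are illustrative only, and the real node_utils.get_weight_info (which is called with just the node) may be implemented quite differently.

# Hypothetical sketch only; not the actual node_utils.get_weight_info implementation.
from typing import NamedTuple, Optional

import torch
import torch.nn as nn
from torch import fx


class WeightInfo(NamedTuple):
    node_key: str          # fully qualified parameter name, e.g. "layers.0.linear.weight"
    tensor: torch.Tensor   # the parameter tensor itself
    submod: nn.Module      # submodule that owns the parameter


def get_weight_info_sketch(gm: fx.GraphModule, node: fx.Node) -> Optional[WeightInfo]:
    """Return the weight parameter feeding a linear node, or None if there is none."""
    for arg in node.all_input_nodes:
        if arg.op == "get_attr" and str(arg.target).endswith("weight"):
            target = str(arg.target)
            modname, _, _ = target.rpartition(".")
            submod = gm.get_submodule(modname) if modname else gm
            return WeightInfo(target, gm.get_parameter(target), submod)
    return None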

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 45 additions & 55 deletions
@@ -40,11 +40,10 @@
     LayerType,
     bfs,
     extract_weight_name,
-    extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
+    get_all_weight_infos,
     get_all_weights_in_subgraph,
-    get_layer_after_linear_node,
     is_any_attention_op,
     is_any_lin_op,
     is_any_moe_op,
@@ -1296,6 +1295,11 @@ def _shard_parameter_node(
     rank, world_size = config.rank, config.world_size
     allreduce_strategy = config.allreduce_strategy.name

+    if "sharded" in node.meta and node.meta["sharded"]:
+        # Node was already sharded, skip
+        return
+    node.meta["sharded"] = True
+
     num_users = num_users_of_weight_node(node)
     if num_users > 1 or num_users == 0:
         ad_logger.warning(
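The new guard makes _shard_parameter_node idempotent: torch.fx nodes carry a free-form meta dict, and the transform records a "sharded" flag there so that a node visited twice (for example through overlapping sharding configs) is only sharded once. A self-contained sketch of the pattern, where run_once_per_node stands in for the real transform and does no actual sharding:

# Minimal sketch of the meta-flag idempotency pattern used above.
import torch.nn as nn
from torch import fx


def run_once_per_node(gm: fx.GraphModule) -> None:
    for node in gm.graph.nodes:
        if node.meta.get("sharded", False):   # same as: "sharded" in node.meta and node.meta["sharded"]
            continue                          # already processed, skip
        node.meta["sharded"] = True
        # ... the actual per-node sharding transform would run here ...


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)


gm = fx.symbolic_trace(Tiny())
run_once_per_node(gm)   # first pass marks every node
run_once_per_node(gm)   # second pass is a no-op thanks to the flag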
@@ -1304,12 +1308,17 @@
         return

     # Shard weight using the unified function (also updates the parameter)
-    weight_nodes = extract_weight_nodes(node)
-    for weight_node in weight_nodes.weights:
+    all_weight_infos = get_all_weight_infos(node)
+    # Parametrized nodes must have at least one weight (for debugging)
+    assert len(all_weight_infos.weights) > 0, (
+        f"Node {node.name} has no weights - weight mapping may be incorrect"
+    )
+
+    for weight_info in all_weight_infos.weights:
         _, weight_new_shape = shard_weight_tensor(
             gm=gm,
-            weight_tensor=weight_node.tensor,
-            param_key=weight_node.node_key,
+            weight_tensor=weight_info.tensor,
+            param_key=weight_info.node_key,
             dim=dim,
             rank=rank,
             world_size=world_size,
@@ -1319,40 +1328,40 @@
         if quantization_cb is not None:
             quantization_cb(
                 gm=gm,
-                submod=weight_node.submod,
+                submod=weight_info.submod,
                 node=node,
-                weight_key=weight_node.node_key,
+                weight_key=weight_info.node_key,
                 weight_new_shape=weight_new_shape,
                 dim=dim,
                 rank=rank,
                 world_size=world_size,
             )

-    for bias_node in weight_nodes.biases:
+    for bias_info in all_weight_infos.biases:
         if dim == 0:
             # update bias for dim 0 --> we can handle it like the weight
             shard_weight_tensor(
                 gm=gm,
-                weight_tensor=bias_node.tensor,
-                param_key=bias_node.node_key,
+                weight_tensor=bias_info.tensor,
+                param_key=bias_info.node_key,
                 dim=dim,
                 rank=rank,
                 world_size=world_size,
                 min_local_shape=min_local_shape,
                 fused_weight_dims=fused_weight_dims,
             )
-        elif bias_node is not None and rank != world_size - 1:
+        elif rank != world_size - 1:
             # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid
             # double counting it. For all other we will delete the bias.
             args = list(node.args)
             node_bias = args[2]
             args[2] = None
             node.args = tuple(args)
             gm.graph.erase_node(node_bias)
-            bias_param_name = bias_node.node_key.rpartition(".")[-1]
-            setattr(bias_node.submod, bias_param_name, None)
+            bias_param_name = bias_info.node_key.rpartition(".")[-1]
+            setattr(bias_info.submod, bias_param_name, None)
             gm._register_load_state_dict_pre_hook(
-                partial(_load_hook_remove, param_key=bias_node.node_key)
+                partial(_load_hook_remove, param_key=bias_info.node_key)
             )

     # # # column shard with no gather: the output is sharded
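Background on the bias handling this hunk touches: for dim == 0 (column split) the bias is sharded like the weight, while for dim == 1 (row split) every rank produces a partial matmul whose results are summed by all_reduce, so adding the bias on every rank would count it world_size times. The patch therefore keeps the bias only on the last rank and deletes it elsewhere. A single-process toy check of that arithmetic, with an explicit sum over ranks standing in for torch.distributed's all_reduce:

# Why only one rank keeps the bias under dim-1 (row) sharding; single-process sketch.
import torch

torch.manual_seed(0)
world_size = 2
x = torch.randn(3, 8)
w = torch.randn(6, 8)          # full weight of a Linear(8, 6)
b = torch.randn(6)

ref = x @ w.T + b              # unsharded reference output

# Split the reduction dimension across "ranks"; each rank holds a slice of x and w.
x_shards = x.chunk(world_size, dim=1)
w_shards = w.chunk(world_size, dim=1)

# Each rank computes a partial product; only the last rank adds the bias.
partials = [
    xs @ ws.T + (b if rank == world_size - 1 else 0)
    for rank, (xs, ws) in enumerate(zip(x_shards, w_shards))
]
out = sum(partials)            # stands in for the all_reduce sum

assert torch.allclose(out, ref, atol=1e-5)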
@@ -2295,47 +2304,37 @@ def detect_sharding_from_config(
         raise ValueError(f"Unsupported sharding source: {source}")
     tp_plan = config["tp_plan"]

-    # If the node is inside the attention module, we need to set min_local_shape to the
-    # head_dim - otherwise, we would risk splitting the heads into smaller shards.
-    # TODO: is there a better way to check if we are in attention module?
-    attn_names = [
-        "attention",
-        "Attention",
-        "attn",
-        "Attn",
-        "q_proj",
-        "k_proj",
-        "v_proj",
-        "o_proj",
-    ]
-
     num_shards = 0
     num_simple_shards = 0
     num_row_col_shards = 0
     num_attention_shards = 0
     num_ssm_shards = 0
-    head_dim = -1
     linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))

+    # use layer_subgraphs to determine the layer_type
+    # and check the validity of the sharding transform
+    layer_subgraphs, unprocessed_linear_nodes = get_all_layer_subgraphs(gm)
+
     for lin_node in linear_nodes:
         # use node's weight name to get the module name
         weight_name = extract_weight_name(lin_node)
-
-        if any(attn_name in weight_name for attn_name in attn_names):
-            # find the next attention node and infer the head_dim
-            next_attention_node, _ = bfs(
-                lin_node, is_any_attention_op, attr_next="users", include_root=False
-            )
-            if next_attention_node is None:
-                # this is the last attention node in the graph. Take the previously found head_dim
-                assert head_dim != -1, "Head dim not found for the last attention node"
-            else:
-                head_dim = shape(next_attention_node)[-1]
-            min_local_shape = head_dim
-            layer_type = LayerType.ATTENTION
+        # get the parent layer_subgraph
+        layer_subgraph = [
+            layer
+            for layer in layer_subgraphs
+            if lin_node in layer.opening_nodes or lin_node == layer.terminating_node
+        ]
+        if len(layer_subgraph) == 1:
+            layer_subgraph = layer_subgraph[0]
+            layer_type = layer_subgraph.layer_type
         else:
-            min_local_shape = 1
-            layer_type = LayerType.MLP
+            if lin_node in unprocessed_linear_nodes:
+                layer_type = LayerType.UNKNOWN
+            else:
+                ad_logger.warning(
+                    f"Failed to find the parent layer_subgraph for linear node {lin_node}. "
+                    f"May result in incorrect sharding."
+                )

         # use regex to find if module_name matches any of the keys in sharding_config
         for key in tp_plan.keys():
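This is the core of the improved layer classification: instead of substring matching on weight names ("attn", "q_proj", ...), each linear node is mapped to the layer subgraph returned by get_all_layer_subgraphs and inherits that subgraph's layer_type; nodes the subgraph analysis could not place fall back to LayerType.UNKNOWN. The sketch below mirrors that lookup with stand-in LayerSubgraph and LayerType definitions. The real types live in auto_deploy's node_utils, and the real code warns via ad_logger rather than printing.

# Illustrative stand-ins; the real layer subgraph objects expose at least
# opening_nodes, terminating_node and layer_type, which is all this sketch relies on.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, List


class LayerType(Enum):
    ATTENTION = auto()
    MLP = auto()
    SSM = auto()
    UNKNOWN = auto()


@dataclass
class LayerSubgraph:
    opening_nodes: List[Any]
    terminating_node: Any
    layer_type: LayerType


def classify_linear_node(lin_node, layer_subgraphs, unprocessed_linear_nodes) -> LayerType:
    matches = [
        layer
        for layer in layer_subgraphs
        if lin_node in layer.opening_nodes or lin_node == layer.terminating_node
    ]
    if len(matches) == 1:
        return matches[0].layer_type
    if lin_node in unprocessed_linear_nodes:
        return LayerType.UNKNOWN
    # The patch only logs a warning here; returning UNKNOWN keeps this sketch total.
    print(f"no unique parent layer subgraph for {lin_node}; treating it as UNKNOWN")
    return LayerType.UNKNOWN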
@@ -2349,11 +2348,6 @@
                 # we have a match. Get the config for this layer
                 config = tp_plan[key]

-                if config in ["colwise", "mamba"]:
-                    cur_node_index = linear_nodes.index(lin_node)
-                    layer_subgraph = get_layer_after_linear_node(
-                        linear_nodes, [cur_node_index - 1], enforce_strict_linear_history=False
-                    )
                 if config == "colwise":
                     _process_column_sharding(
                         layer_subgraph=layer_subgraph,
@@ -2366,7 +2360,6 @@
                             split_dim=SplitDimension.ROW,
                             config=transform_container.config,
                             dist_op="all_reduce",
-                            min_local_shape=min_local_shape,
                             layer_type=layer_type,
                         )
                     ):
@@ -2393,7 +2386,6 @@
                             split_dim=SplitDimension.COLUMN,
                             config=transform_container.config,
                             dist_op=None,
-                            min_local_shape=min_local_shape,
                             layer_type=layer_type,
                         )
                     )
@@ -2404,7 +2396,6 @@
                             split_dim=SplitDimension.ROW,
                             config=transform_container.config,
                             dist_op="all_reduce",
-                            min_local_shape=min_local_shape,
                             layer_type=layer_type,
                         )
                     ):
@@ -2423,7 +2414,6 @@
                             split_dim=SplitDimension.COLUMN,
                             config=transform_container.config,
                             dist_op="all_gather",
-                            min_local_shape=1,
                             layer_type=layer_type,
                         )
                     ):
@@ -2536,7 +2526,7 @@ def detect_column_row_shard(
         attention_nodes = list(filtered_nodes(layer_subgraph, is_any_attention_op))
         min_local_shape = 1

-        if config.simple_shard_only:
+        if config.simple_shard_only or layer.layer_type == LayerType.UNKNOWN:
             ad_logger.debug(
                 f"Forcing Simple Shard on nodes: {nodes_linear} with layer type: {layer.layer_type}"
             )
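The new condition forces a simple shard not only when simple_shard_only is set but also whenever the layer could not be classified (LayerType.UNKNOWN). A simple shard, roughly, column-splits each linear independently and all-gathers the outputs, so it stays correct without any knowledge of the surrounding layer structure. A single-process toy version of that identity, with torch.cat standing in for the all_gather:

# Toy illustration of the simple-shard identity: column-split the weight, gather outputs.
import torch

torch.manual_seed(0)
world_size = 2
x = torch.randn(3, 8)
w = torch.randn(6, 8)          # full weight of a Linear(8, 6)
b = torch.randn(6)

ref = x @ w.T + b              # unsharded reference output

# Each "rank" holds a slice of the output features (weight rows and bias entries).
w_shards = w.chunk(world_size, dim=0)
b_shards = b.chunk(world_size, dim=0)
outs = [x @ ws.T + bs for ws, bs in zip(w_shards, b_shards)]
out = torch.cat(outs, dim=-1)  # stands in for all_gather along the feature dim

assert torch.allclose(out, ref, atol=1e-5)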
