
Commit 4139524

fixed rebase issue
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
1 parent b3146d0 commit 4139524

4 files changed: +98 -93 lines changed


tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 24 additions & 69 deletions
@@ -43,7 +43,7 @@
     extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
-    get_layer_after_linear_node,
+    get_all_weights_in_subgraph,
     is_any_attention_op,
     is_any_lin_op,
     is_any_moe_op,
@@ -1060,31 +1060,6 @@ def _resolve_tp_cls_from_node(node: Node):
     return WeightShardingInfo


-def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int:
-    """Helper to get the first dimension size of an argument (Node or Tensor)."""
-    if isinstance(arg, torch.Tensor):
-        return arg.shape[0]
-    if isinstance(arg, Node):
-        if arg.op == "get_attr":
-            # Traverse attributes to find the tensor
-            obj = gm
-            for atom in arg.target.split("."):
-                obj = getattr(obj, atom)
-            return obj.shape[0]
-        if "val" in arg.meta:
-            return shape(arg)[0]
-    raise ValueError(f"Cannot determine shape[0] for {arg}")
-
-
-def get_all_weights_in_subgraph(
-    sources: list[Node],
-    sinks: list[Node],
-):
-    """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
-    weight_nodes = subgraph(sources, sinks, include=lambda n: n.op == "get_attr")
-    return weight_nodes
-
-
 def init_process_grid_from_config(
     config: ShardingTransformConfig,
 ) -> Dict[ShardingDim, Dict[str, int]]:
@@ -1247,6 +1222,7 @@ def _shard_parameter_node(

     # Shard weight using the unified function (also updates the parameter)
     weight_nodes = extract_weight_nodes(node)
+
     for weight_node in weight_nodes.weights:
         _, weight_new_shape = shard_weight_tensor(
             gm=gm,
@@ -1532,9 +1508,7 @@ def _insert_sharded_mxfp4_mlp_ep(

     # Add a dist all-reduce after the op (sum partial results across EP ranks)
     with gm.graph.inserting_after(node):
-        red = gm.graph.call_function(
-            torch.ops.auto_deploy.torch_dist_all_reduce, args=(node, config.allreduce_strategy.name)
-        )
+        red = gm.graph.call_function(torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,))
     node.replace_all_uses_with(red)
     # keep dataflow: red(input=node)
     red.replace_input_with(red, node)
@@ -2018,47 +1992,37 @@ def detect_sharding_from_config(
         raise ValueError(f"Unsupported sharding source: {source}")
     tp_plan = config["tp_plan"]

-    # If the node is inside the attention module, we need to set min_local_shape to the
-    # head_dim - otherwise, we would risk splitting the heads into smaller shards.
-    # TODO: is there a better way to check if we are in attention module?
-    attn_names = [
-        "attention",
-        "Attention",
-        "attn",
-        "Attn",
-        "q_proj",
-        "k_proj",
-        "v_proj",
-        "o_proj",
-    ]
-
     num_shards = 0
     num_simple_shards = 0
     num_row_col_shards = 0
     num_attention_shards = 0
     num_ssm_shards = 0
-    head_dim = -1
     linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))

+    # use layer_subgraphs to determine the layer_type
+    # and check the validity of the sharding transform
+    layer_subgraphs, unprocessed_linear_nodes = get_all_layer_subgraphs(gm)
+
     for lin_node in linear_nodes:
         # use node's weight name to get the module name
         weight_name = extract_weight_name(lin_node)
-
-        if any(attn_name in weight_name for attn_name in attn_names):
-            # find the next attention node and infer the head_dim
-            next_attention_node, _ = bfs(
-                lin_node, is_any_attention_op, attr_next="users", include_root=False
-            )
-            if next_attention_node is None:
-                # this is the last attention node in the graph. Take the previously found head_dim
-                assert head_dim != -1, "Head dim not found for the last attention node"
-            else:
-                head_dim = shape(next_attention_node)[-1]
-            min_local_shape = head_dim
-            layer_type = LayerType.ATTENTION
+        # get the parent layer_subgraph
+        layer_subgraph = [
+            layer
+            for layer in layer_subgraphs
+            if lin_node in layer.opening_nodes or lin_node == layer.terminating_node
+        ]
+        if len(layer_subgraph) == 1:
+            layer_subgraph = layer_subgraph[0]
+            layer_type = layer_subgraph.layer_type
         else:
-            min_local_shape = 1
-            layer_type = LayerType.MLP
+            if lin_node in unprocessed_linear_nodes:
+                layer_type = LayerType.UNKNOWN
+            else:
+                ad_logger.warning(
+                    f"Failed to find the parent layer_subgraph for linear node {lin_node}. Skipping."
+                )
+                continue

         # use regex to find if module_name matches any of the keys in sharding_config
         for key in tp_plan.keys():
@@ -2072,11 +2036,6 @@ def detect_sharding_from_config(
             # we have a match. Get the config for this layer
             config = tp_plan[key]

-            if config in ["colwise", "mamba"]:
-                cur_node_index = linear_nodes.index(lin_node)
-                layer_subgraph = get_layer_after_linear_node(
-                    linear_nodes, [cur_node_index - 1], enforce_strict_linear_history=False
-                )
             if config == "colwise":
                 _process_column_sharding(
                     layer_subgraph=layer_subgraph,
@@ -2089,7 +2048,6 @@ def detect_sharding_from_config(
                     split_dim=SplitDimension.ROW,
                     config=transform_container.config,
                     dist_op="all_reduce",
-                    min_local_shape=min_local_shape,
                     layer_type=layer_type,
                 )
             ):
@@ -2116,7 +2074,6 @@ def detect_sharding_from_config(
                     split_dim=SplitDimension.COLUMN,
                     config=transform_container.config,
                     dist_op=None,
-                    min_local_shape=min_local_shape,
                     layer_type=layer_type,
                 )
             )
@@ -2127,7 +2084,6 @@ def detect_sharding_from_config(
                     split_dim=SplitDimension.ROW,
                     config=transform_container.config,
                     dist_op="all_reduce",
-                    min_local_shape=min_local_shape,
                     layer_type=layer_type,
                 )
             ):
@@ -2146,7 +2102,6 @@ def detect_sharding_from_config(
                     split_dim=SplitDimension.COLUMN,
                     config=transform_container.config,
                     dist_op="all_gather",
-                    min_local_shape=1,
                     layer_type=layer_type,
                 )
             ):
@@ -2259,7 +2214,7 @@ def detect_column_row_shard(
     attention_nodes = list(filtered_nodes(layer_subgraph, is_any_attention_op))
     min_local_shape = 1

-    if config.simple_shard_only:
+    if config.simple_shard_only or layer.layer_type == LayerType.UNKNOWN:
         ad_logger.debug(
             f"Forcing Simple Shard on nodes: {nodes_linear} with layer type: {layer.layer_type}"
         )
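
Note on the detect_sharding_from_config change above: instead of guessing the layer type from attention-related substrings in the weight name, each linear node is now matched against the layer subgraphs returned by get_all_layer_subgraphs. The snippet below is a minimal, self-contained sketch of that lookup under simplified assumptions; LayerType, LayerSubgraph, resolve_layer_type, and the string node names are hypothetical stand-ins, not the classes from tensorrt_llm.

# Sketch only: simplified stand-ins for the real LayerSubgraph/LayerType objects.
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Optional, Set


class LayerType(Enum):
    MLP = auto()
    ATTENTION = auto()
    SSM = auto()
    MLA = auto()
    UNKNOWN = auto()


@dataclass
class LayerSubgraph:
    opening_nodes: List[str]
    terminating_node: Optional[str]
    layer_type: LayerType
    subgraph_nodes: List[str] = field(default_factory=list)


def resolve_layer_type(
    lin_node: str,
    layer_subgraphs: List[LayerSubgraph],
    unprocessed_linear_nodes: Set[str],
) -> Optional[LayerType]:
    # Mirror of the new loop body: find the unique parent subgraph of a linear node.
    matches = [
        layer
        for layer in layer_subgraphs
        if lin_node in layer.opening_nodes or lin_node == layer.terminating_node
    ]
    if len(matches) == 1:
        return matches[0].layer_type
    if lin_node in unprocessed_linear_nodes:
        return LayerType.UNKNOWN
    # The real code logs a warning and skips the node in this case.
    return None


# Toy usage with made-up projection names:
subgraphs = [
    LayerSubgraph(["q_proj", "k_proj", "v_proj"], "o_proj", LayerType.ATTENTION),
    LayerSubgraph(["gate_proj", "up_proj"], "down_proj", LayerType.MLP),
]
print(resolve_layer_type("k_proj", subgraphs, set()))         # LayerType.ATTENTION
print(resolve_layer_type("lm_head", subgraphs, {"lm_head"}))  # LayerType.UNKNOWN

The UNKNOWN fallback is what lets detect_column_row_shard force a simple shard for layers it cannot classify, as in the simple_shard_only hunk above.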

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 65 additions & 19 deletions
@@ -143,8 +143,19 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
     return input_params, weight_params, output_params


-def extract_weight_name(node: Node) -> str:
+def get_all_weights_in_subgraph(
+    sources: list[Node],
+    sinks: list[Node],
+):
+    """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
+    weight_nodes = subgraph(sources, sinks, include=is_weight_node)
+    return weight_nodes
+
+
+def extract_weight_name(node: Node) -> Union[str, bool]:
     weight_nodes = extract_weight_nodes(node)
+    if len(weight_nodes.weights) == 0:
+        return False
     return weight_nodes.weights[0].node_key


@@ -431,6 +442,10 @@ def is_dist_op(node: Node) -> bool:
     return is_op(node, dist_ops)


+def is_weight_node(node: Node) -> bool:
+    return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0
+
+
 def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
     """Get a user from a node if the node matches a given op set and num of users."""
     if node is None:
@@ -531,6 +546,9 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
     assert gm.graph.nodes, "Graph is empty"
     layer_subgraphs = []
     linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))
+
+    # find the embedding size of this model. Extract it from the input of the first linear node.
+    embd = get_weight_shape(linear_nodes[0], dim=-1)
     unprocessed_linear_nodes = set(linear_nodes)
     assert len(linear_nodes) > 0, "Could not find any linear nodes in the graph"

@@ -542,7 +560,7 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
         # opening is the list of linear nodes
         # layer_subgraph is the list of nodes between the opening and closing linear nodes
         # closing is the last linear node in the layer
-        layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices)
+        layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices, embd=embd)
         if layer_subgraph.opening_nodes is not None and len(layer_subgraph.opening_nodes) > 0:
             unprocessed_linear_nodes -= (
                 set(layer_subgraph.opening_nodes)
@@ -808,6 +826,7 @@ def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[in
 def get_layer_after_linear_node(
     linear_nodes: List[Node],
     terminating_indices: List[int],
+    embd: Optional[int] = None,
     match_on_shapes: bool = True,
     enforce_strict_linear_history: bool = True,
 ) -> LayerSubgraph:
@@ -882,8 +901,9 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
             layer_type=LayerType.UNKNOWN,
         )
     if match_on_shapes:
-        # get embedding size of the opening linear node
-        embd = get_weight_shape(linear_nodes[start_lin_index], dim=-1)
+        if embd is None:
+            # get embedding size of the opening linear node
+            embd = get_weight_shape(linear_nodes[start_lin_index], dim=-1)
         # partial init boundary_condition and filter_condition
         boundary_condition = partial(boundary_condition, embd=embd, dim=0)
         filter_condition = partial(filter_condition, embd=embd, dim=0)
@@ -892,6 +912,18 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
             sources=[linear_nodes[start_lin_index]], boundary_condition=boundary_condition
         )
         lin_nodes_in_subgraph = list(filtered_nodes(forward_subgraph, filter_condition))
+        if len(lin_nodes_in_subgraph) > 1:
+            # it means that probably we went over the boundary of the layer.
+            # It may happen e.g., with MoLE (latent MoE), with the closing latent fc2 projection,
+            # when the subgraph spanned over fc2 "spills" over consecutive layers.
+            # Then, wrap this single linear node in LayerType.UNKNOWN and return.
+            terminating_indices.append(start_lin_index)
+            return LayerSubgraph(
+                opening_nodes=[linear_nodes[start_lin_index]],
+                subgraph_nodes=[],
+                terminating_node=linear_nodes[start_lin_index],
+                layer_type=LayerType.UNKNOWN,
+            )
         start_lin_index += 1
     start_lin_index -= 1
     terminating_linear_node = lin_nodes_in_subgraph[0]
@@ -924,25 +956,39 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
     ssm_nodes = list(filtered_nodes(interior_nodes, is_any_ssm_op))
     attention_nodes = list(filtered_nodes(interior_nodes, is_any_attention_op))
     intermediate_lin_nodes = list(filtered_nodes(interior_nodes, is_any_lin_op))
+    intermediate_weight_nodes = list(
+        filtered_nodes(
+            interior_nodes, lambda n: is_weight_node(n) and not is_any_lin_op(list(n.users)[0])
+        )
+    )

     layer_type = LayerType.MLP
     min_local_shape = 1
     if len(ssm_nodes) > 0:
-        assert len(ssm_nodes) == 1, "SSM layer must have exactly one SSM node"
-        layer_type = LayerType.SSM
-        # determine head size
-        min_local_shape = shape(ssm_nodes[0])[-1]
-    if len(attention_nodes) > 0:
-        assert len(attention_nodes) == 1, "Attention layer must have exactly one attention node"
-        layer_type = LayerType.ATTENTION
-        # determine head size
-        min_local_shape = shape(attention_nodes[0])[-1]
-    if len(intermediate_lin_nodes) > 0:
-        assert len(intermediate_lin_nodes) == 2, (
-            "MLA layer must have exactly two intermediate linear nodes"
-        )
-        assert len(attention_nodes) == 1, "MLA layer must have exactly one attention node"
-        layer_type = LayerType.MLA
+        if len(ssm_nodes) == 1:
+            layer_type = LayerType.SSM
+            # determine head size
+            min_local_shape = shape(ssm_nodes[0])[-1]
+        else:
+            layer_type = LayerType.UNKNOWN
+    if len(attention_nodes) > 0 and layer_type != LayerType.UNKNOWN:
+        if len(attention_nodes) == 1:
+            layer_type = LayerType.ATTENTION
+            # determine head size
+            min_local_shape = shape(attention_nodes[0])[-1]
+        else:
+            layer_type = LayerType.UNKNOWN
+    if len(intermediate_lin_nodes) > 0 and layer_type != LayerType.UNKNOWN:
+        if len(intermediate_lin_nodes) == 2 and len(attention_nodes) == 1:
+            layer_type = LayerType.MLA
+        else:
+            layer_type = LayerType.UNKNOWN
+    # only SSM or MLA layers can have weight nodes in the interior nodes
+    # TODO: Minimax does have RMSNorm inside attention, we need to
+    # support it in the future.
+    if len(intermediate_weight_nodes) > 0:
+        if layer_type not in [LayerType.SSM, LayerType.MLA]:
+            layer_type = LayerType.UNKNOWN

     layer_subgraph = LayerSubgraph(
         opening_nodes=opening_linear_nodes,
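
For context on the new is_weight_node and get_all_weights_in_subgraph helpers: they classify get_attr nodes that carry a tensor with at least one dimension. Below is a rough, self-contained sketch of the same idea on a toy torch.fx graph; it is not the repository's implementation (the has_shape/shape meta helpers are replaced by direct tensor inspection, and TinyMLP/toy_is_weight_node are made-up names for illustration).

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import GraphModule, Node, symbolic_trace


def _resolve_get_attr(gm: GraphModule, node: Node) -> torch.Tensor:
    # Walk the dotted target (e.g. "block.weight") down to the actual attribute.
    obj = gm
    for atom in node.target.split("."):
        obj = getattr(obj, atom)
    return obj


def toy_is_weight_node(gm: GraphModule, node: Node) -> bool:
    # Stand-in for is_weight_node: a get_attr node backed by a tensor with >= 1 dim.
    if node.op != "get_attr" or not node.target:
        return False
    t = _resolve_get_attr(gm, node)
    return isinstance(t, torch.Tensor) and t.dim() > 0


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.w_up = nn.Parameter(torch.randn(16, 8))
        self.w_down = nn.Parameter(torch.randn(8, 16))

    def forward(self, x):
        return F.linear(torch.relu(F.linear(x, self.w_up)), self.w_down)


gm = symbolic_trace(TinyMLP())
weight_nodes = [n for n in gm.graph.nodes if toy_is_weight_node(gm, n)]
print([n.target for n in weight_nodes])  # -> ['w_up', 'w_down']

The real helper relies on the graph's shape metadata (has_shape/shape) rather than touching module attributes directly, which keeps it consistent with the rest of the node_utils API.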
