
Commit b784fe9

fixed latent MoE sharding
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
1 parent d6fb7f7 commit b784fe9

File tree

2 files changed: +16 −5 lines changed


tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1223,6 +1223,7 @@ def _shard_parameter_node(
 
     # Shard weight using the unified function (also updates the parameter)
     weight_nodes = extract_weight_nodes(node)
+
     for weight_node in weight_nodes.weights:
         _, weight_new_shape = shard_weight_tensor(
             gm=gm,
@@ -2214,7 +2215,7 @@ def detect_column_row_shard(
     attention_nodes = list(filtered_nodes(layer_subgraph, is_any_attention_op))
     min_local_shape = 1
 
-    if config.simple_shard_only:
+    if config.simple_shard_only or layer.layer_type == LayerType.UNKNOWN:
         ad_logger.debug(
             f"Forcing Simple Shard on nodes: {nodes_linear} with layer type: {layer.layer_type}"
         )
```
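
The second hunk is the functional fix: a layer whose type the detector could not classify (`LayerType.UNKNOWN`, e.g. an unrecognized latent-MoE block) is now routed to the simple shard, the same path an explicit `simple_shard_only` config takes. A minimal sketch of the guard, assuming hypothetical stand-ins for the config object and the strategy names (the real pass operates on FX graph nodes, not strings):

```python
from dataclasses import dataclass
from enum import Enum, auto


class LayerType(Enum):
    MLP = auto()
    MOE = auto()
    UNKNOWN = auto()


@dataclass
class ShardingConfig:  # hypothetical stand-in for the real config object
    simple_shard_only: bool = False


def choose_strategy(config: ShardingConfig, layer_type: LayerType) -> str:
    # Column/row (tensor-parallel) sharding needs to understand the layer's
    # structure; when classification failed, the safe fallback is the simple
    # shard, just as if simple_shard_only had been requested.
    if config.simple_shard_only or layer_type == LayerType.UNKNOWN:
        return "simple_shard"
    return "column_row_shard"


assert choose_strategy(ShardingConfig(), LayerType.UNKNOWN) == "simple_shard"
assert choose_strategy(ShardingConfig(), LayerType.MLP) == "column_row_shard"
```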

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 14 additions & 4 deletions
```diff
@@ -546,6 +546,9 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
     assert gm.graph.nodes, "Graph is empty"
     layer_subgraphs = []
     linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))
+
+    # find the embedding size of this model. Extract it from the input of the first linear node.
+    embd = get_weight_shape(linear_nodes[0], dim=-1)
     unprocessed_linear_nodes = set(linear_nodes)
     assert len(linear_nodes) > 0, "Could not find any linear nodes in the graph"
 
```
```diff
@@ -557,7 +560,7 @@
     # opening is the list of linear nodes
     # layer_subgraph is the list of nodes between the opening and closing linear nodes
     # closing is the last linear node in the layer
-    layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices)
+    layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices, embd=embd)
     if layer_subgraph.opening_nodes is not None and len(layer_subgraph.opening_nodes) > 0:
         unprocessed_linear_nodes -= (
             set(layer_subgraph.opening_nodes)
```
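
These two hunks work together: the model-wide embedding size is computed once, from the input dimension of the very first linear node, and passed down to `get_layer_after_linear_node` instead of being re-derived inside every layer (the per-layer derivation is presumably what misfired on latent-MoE layers, whose projections need not consume the hidden size directly). The `dim=-1` lookup works because a linear weight is stored as `(out_features, in_features)`; a plain-PyTorch sketch of that convention, not the FX-node helper itself:

```python
import torch.nn as nn

# nn.Linear stores its weight as (out_features, in_features), so the last
# weight dimension of the model's first projection is the hidden/embedding
# size that layer boundaries are matched against.
embedding_size = 4096
first_proj = nn.Linear(embedding_size, 3 * embedding_size, bias=False)

embd = first_proj.weight.shape[-1]
assert embd == embedding_size
```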
```diff
@@ -823,6 +826,7 @@ def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[in
 def get_layer_after_linear_node(
     linear_nodes: List[Node],
     terminating_indices: List[int],
+    embd: Optional[int] = None,
     match_on_shapes: bool = True,
     enforce_strict_linear_history: bool = True,
 ) -> LayerSubgraph:
```
```diff
@@ -897,8 +901,9 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
             layer_type=LayerType.UNKNOWN,
         )
     if match_on_shapes:
-        # get embedding size of the opening linear node
-        embd = get_weight_shape(linear_nodes[start_lin_index], dim=-1)
+        if embd is None:
+            # get embedding size of the opening linear node
+            embd = get_weight_shape(linear_nodes[start_lin_index], dim=-1)
         # partial init boundary_condition and filter_condition
         boundary_condition = partial(boundary_condition, embd=embd, dim=0)
         filter_condition = partial(filter_condition, embd=embd, dim=0)
```
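
`embd` thus becomes a caller override: the model-wide value supplied by `get_all_layer_subgraphs` wins, and only when it is `None` does the function fall back to the old per-layer derivation from the opening linear node. A sketch of the pattern, including the `functools.partial` binding the code uses to fix `embd` inside its predicates; shapes and helper names here are simplified stand-ins:

```python
from functools import partial
from typing import Optional, Tuple


def is_boundary(shape: Tuple[int, ...], embd: Optional[int] = None, dim: int = 0) -> bool:
    # A weight closes the layer when its `dim` size equals the embedding size.
    return shape[dim] == embd


def scan_layer(opening_weight_shape: Tuple[int, ...], embd: Optional[int] = None) -> bool:
    if embd is None:
        # fallback: derive embd from the opening linear node's input dim
        embd = opening_weight_shape[-1]
    # bind embd (and dim) once so the predicate can be handed around freely
    boundary = partial(is_boundary, embd=embd, dim=0)
    return boundary((embd, 4 * embd))


assert scan_layer((11008, 4096)) is True             # embd derived locally
assert scan_layer((11008, 4096), embd=4096) is True  # embd passed by caller
```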
```diff
@@ -951,7 +956,12 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
     ssm_nodes = list(filtered_nodes(interior_nodes, is_any_ssm_op))
     attention_nodes = list(filtered_nodes(interior_nodes, is_any_attention_op))
     intermediate_lin_nodes = list(filtered_nodes(interior_nodes, is_any_lin_op))
-    intermediate_weight_nodes = list(filtered_nodes(interior_nodes, is_weight_node))
+    intermediate_weight_nodes = list(
+        filtered_nodes(
+            interior_nodes, lambda n: is_weight_node(n) and
+            not is_any_lin_op(list(n.users)[0])
+        )
+    )
 
     layer_type = LayerType.MLP
     min_local_shape = 1
```
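
The widened filter stops counting every weight in the layer interior as a free-standing weight node: a weight whose first consumer is itself a linear op is now excluded, presumably because it is already represented in `intermediate_lin_nodes`. In torch.fx, `node.users` is an insertion-ordered dict of consumer nodes, so `list(n.users)[0]` picks the first consumer. A toy trace showing that lookup (not the auto_deploy pass itself):

```python
import torch
import torch.fx as fx


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(8, 8))

    def forward(self, x):
        return torch.nn.functional.linear(x, self.w)


gm = fx.symbolic_trace(Toy())
for n in gm.graph.nodes:
    if n.op == "get_attr":             # the node fetching self.w
        first_user = list(n.users)[0]  # first consumer of the weight
        print(n.target, "->", first_user.target)  # here: the linear call
```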
