Commit 47d9815

fixed latent MoE sharding
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
1 parent: 6149ede · commit: 47d9815

File tree

2 files changed (+16, -5 lines)

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 2 additions & 1 deletion
@@ -1246,6 +1246,7 @@ def _shard_parameter_node(
 
     # Shard weight using the unified function (also updates the parameter)
    weight_nodes = extract_weight_nodes(node)
+
    for weight_node in weight_nodes.weights:
        _, weight_new_shape = shard_weight_tensor(
            gm=gm,
@@ -2237,7 +2238,7 @@ def detect_column_row_shard(
    attention_nodes = list(filtered_nodes(layer_subgraph, is_any_attention_op))
    min_local_shape = 1
 
-   if config.simple_shard_only:
+   if config.simple_shard_only or layer.layer_type == LayerType.UNKNOWN:
        ad_logger.debug(
            f"Forcing Simple Shard on nodes: {nodes_linear} with layer type: {layer.layer_type}"
        )
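
The net effect in detect_column_row_shard is that a layer whose type could not be classified now takes the conservative path instead of column/row tensor-parallel sharding. A minimal sketch of the decision, assuming nothing beyond what the hunk shows (the choose_sharding helper and the enum members other than UNKNOWN are illustrative, not the repository's API):

    from enum import Enum, auto

    class LayerType(Enum):
        MLP = auto()      # illustrative member
        MOE = auto()      # illustrative member
        UNKNOWN = auto()  # pattern matching failed

    def choose_sharding(simple_shard_only: bool, layer_type: LayerType) -> str:
        # Column/row sharding relies on a recognized layer pattern, so an
        # unclassified layer now falls back to the simple shard as well.
        if simple_shard_only or layer_type == LayerType.UNKNOWN:
            return "simple"
        return "column_row"

    assert choose_sharding(False, LayerType.UNKNOWN) == "simple"
    assert choose_sharding(False, LayerType.MLP) == "column_row"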

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 14 additions & 4 deletions
@@ -535,6 +535,9 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
    assert gm.graph.nodes, "Graph is empty"
    layer_subgraphs = []
    linear_nodes = list(filtered_nodes(gm.graph.nodes, is_any_lin_op))
+
+   # find the embedding size of this model. Extract it from the input of the first linear node.
+   embd = get_weight_shape(linear_nodes[0], dim=-1)
    unprocessed_linear_nodes = set(linear_nodes)
    assert len(linear_nodes) > 0, "Could not find any linear nodes in the graph"
 
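The new embd lookup leans on the PyTorch convention that an nn.Linear weight is stored as (out_features, in_features): the last weight dimension of the first linear node is the width of the hidden states it consumes, i.e. the model's embedding size. A self-contained illustration (the 4096/11008 sizes are invented):

    import torch.nn as nn

    # nn.Linear stores weight as (out_features, in_features), so the last
    # dimension of the first linear node's weight is the model's hidden size.
    lin = nn.Linear(4096, 11008)  # hypothetical up-projection, hidden size 4096
    assert lin.weight.shape[-1] == 4096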
@@ -546,7 +549,7 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
        # opening is the list of linear nodes
        # layer_subgraph is the list of nodes between the opening and closing linear nodes
        # closing is the last linear node in the layer
-       layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices)
+       layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices, embd=embd)
        if layer_subgraph.opening_nodes is not None and len(layer_subgraph.opening_nodes) > 0:
            unprocessed_linear_nodes -= (
                set(layer_subgraph.opening_nodes)
@@ -812,6 +815,7 @@ def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[in
 def get_layer_after_linear_node(
    linear_nodes: List[Node],
    terminating_indices: List[int],
+   embd: Optional[int] = None,
    match_on_shapes: bool = True,
    enforce_strict_linear_history: bool = True,
 ) -> LayerSubgraph:
@@ -886,8 +890,9 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
            layer_type=LayerType.UNKNOWN,
        )
    if match_on_shapes:
-       # get embedding size of the opening linear node
-       embd = get_weight_shape(linear_nodes[start_lin_index], dim=-1)
+       if embd is None:
+           # get embedding size of the opening linear node
+           embd = get_weight_shape(linear_nodes[start_lin_index], dim=-1)
        # partial init boundary_condition and filter_condition
        boundary_condition = partial(boundary_condition, embd=embd, dim=0)
        filter_condition = partial(filter_condition, embd=embd, dim=0)
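
Taken together, the hunks at lines 535, 549, 815, and 890 compute one model-wide embd and thread it into get_layer_after_linear_node, keeping the old per-layer derivation only as a fallback when no value is passed. A plausible reading, given the commit message, is that this is what fixes latent MoE sharding: if a layer's opening linear projects out of a smaller latent space, deriving embd from that layer yields the latent width rather than the model's hidden size, and shape-based boundary matching then misclassifies the layer. A hypothetical sketch of the precedence (the helper name and all sizes are invented):

    from typing import Optional

    def resolve_embd(opening_in_dim: int, embd: Optional[int] = None) -> int:
        # The model-wide value, computed once from the first linear node,
        # wins; the opening node's own input width is only a fallback.
        return embd if embd is not None else opening_in_dim

    assert resolve_embd(1536, embd=4096) == 4096  # new behavior: global hidden size
    assert resolve_embd(1536) == 1536             # old fallback: per-layer latent width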
@@ -940,7 +945,12 @@ def filter_condition(node: Node, embd: Optional[int] = None, dim: Optional[int]
    ssm_nodes = list(filtered_nodes(interior_nodes, is_any_ssm_op))
    attention_nodes = list(filtered_nodes(interior_nodes, is_any_attention_op))
    intermediate_lin_nodes = list(filtered_nodes(interior_nodes, is_any_lin_op))
-   intermediate_weight_nodes = list(filtered_nodes(interior_nodes, is_weight_node))
+   intermediate_weight_nodes = list(
+       filtered_nodes(
+           interior_nodes, lambda n: is_weight_node(n) and
+           not is_any_lin_op(list(n.users)[0])
+       )
+   )
 
    layer_type = LayerType.MLP
    min_local_shape = 1
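
The reworked intermediate_weight_nodes filter keeps only weight nodes whose first user is not a linear op, so weights that merely feed a linear layer no longer count as standalone intermediate weights (those are already captured via the linear nodes themselves). The following torch.fx sketch shows the distinction on a hypothetical module; Block and the printout are illustrative, not the repository's code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.fx as fx

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.w = nn.Parameter(torch.randn(8, 8))  # consumed by a linear op
            self.scale = nn.Parameter(torch.ones(8))  # standalone, norm-like

        def forward(self, x):
            return F.linear(x, self.w) * self.scale

    gm = fx.symbolic_trace(Block())
    for n in gm.graph.nodes:
        if n.op == "get_attr":
            user = list(n.users)[0]  # like the diff's lambda, assumes >= 1 user
            feeds_linear = user.op == "call_function" and user.target is F.linear
            print(f"{n.target}: feeds linear = {feeds_linear}")
    # `w` feeds a linear op and would be excluded by the new filter;
    # only `scale` survives as an intermediate weight node.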
