
Commit 6de8a85

cleanup in progress
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent: ca705a9

3 files changed (+8 −7 lines changed)


tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -13,7 +13,7 @@
 from ...shim.interface import CachedSequenceInterface
 from ...utils.cuda_mem_tracker import cuda_memory_tracker
 from ...utils.logger import ad_logger
-from ...utils.node_utils import extract_param_names_from_node, is_linear_op, is_op
+from ...utils.node_utils import extract_weight_name, is_linear_op, is_op
 from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
 
 
```
```diff
@@ -36,7 +36,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
     y2 = y[:, out1:out1+out2]
     """
     # some info we need
-    keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
+    keys_unfused = [extract_weight_name(n) for n in linear_nodes]
     params_unfused = [gm.get_parameter(k) for k in keys_unfused]
     sizes_unfused = [p.size(0) for p in params_unfused]
     key_fused = f"fused_weight_{idx}"
```
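For context, the docstring fragment visible in this hunk describes what `_insert_fused_gemm` produces: two linears that share an input are replaced by a single GEMM over a row-wise concatenated weight, and the fused output is sliced back apart. A minimal self-contained sketch of that arithmetic (plain torch, not the graph transform itself; the names `w1`/`w2` are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 16)    # shared input of both linear layers
w1 = torch.randn(32, 16)  # weight of the first linear (out1 = 32)
w2 = torch.randn(8, 16)   # weight of the second linear (out2 = 8)

# Fused path: one GEMM over the row-wise concatenated weight.
w_fused = torch.cat([w1, w2], dim=0)
y = F.linear(x, w_fused)

# Split the fused output back into the two original results,
# mirroring the docstring's y2 = y[:, out1:out1+out2].
out1, out2 = w1.size(0), w2.size(0)
y1 = y[:, :out1]
y2 = y[:, out1:out1 + out2]

assert torch.allclose(y1, F.linear(x, w1), atol=1e-5)
assert torch.allclose(y2, F.linear(x, w2), atol=1e-5)
```

Note how `sizes_unfused = [p.size(0) for p in params_unfused]` in the hunk records exactly these per-layer output sizes, which determine the split points.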
```diff
@@ -128,7 +128,7 @@ def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple
     def _insert_fused_quant_gemm(
         self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]
     ):
-        keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
+        keys_unfused = [extract_weight_name(n) for n in linear_nodes]
         params_unfused = [gm.get_parameter(k) for k in keys_unfused]
         sizes_unfused = [p.size(0) for p in params_unfused]
         key_fused = f"fused_weight_{idx}"
```
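Both hunks swap `extract_param_names_from_node(n)[0]`, which indexed into a tuple of parameter names, for the new `extract_weight_name(n)` helper that returns just the weight name. The helper's definition is not part of this diff; a hypothetical sketch of what it could look like under torch.fx conventions, where module parameters enter a traced graph as `get_attr` nodes:

```python
# Hypothetical sketch -- the real extract_weight_name lives in
# tensorrt_llm/_torch/auto_deploy/utils/node_utils.py and is not shown here.
from torch.fx import Node


def extract_weight_name(node: Node) -> str:
    """Return the qualified name of the weight parameter feeding a linear node.

    Assumes the weight is passed as a get_attr input to the linear op, which
    is how torch.fx exposes module parameters in a traced GraphModule.
    """
    for arg in node.args:
        if isinstance(arg, Node) and arg.op == "get_attr":
            return str(arg.target)  # e.g. "layers.0.self_attn.q_proj.weight"
    raise RuntimeError(f"no weight get_attr input found for {node}")
```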

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -2336,7 +2336,6 @@ def detect_column_row_shard(
     min_local_shape is the minimum size of the local tensor shard, to prevent TP parallelism
     splitting, e.g., the individual heads into smaller shards.
     """
-    # test_moe_variants()
     ad_logger.debug("Before sharding graph: " + str(gm))
     config = transform_container.config
     world_size = config.world_size
```
```diff
@@ -2441,7 +2440,7 @@ def detect_column_row_shard(
     # simple shard remaining linear nodes
     if config.shard_all_unprocessed:
         num_simple_shards += _process_simple_shard(unprocessed_linear_nodes, transform_container)
-    num_column_row_shards += num_ssm_shards
+    num_column_row_shards += num_ssm_shards + num_mla_shards
     num_shards = num_simple_shards + num_column_row_shards
     ad_logger.info(
         f"Heuristics found {num_shards} TP shards. Simple: {num_simple_shards}, "
```

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -575,8 +575,10 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]:
         # closing is the last linear node in the layer
         layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices)
         if layer_subgraph.opening_nodes is not None and len(layer_subgraph.opening_nodes) > 0:
-            unprocessed_linear_nodes -= set(layer_subgraph.opening_nodes) | set(
-                [layer_subgraph.terminating_node]
+            unprocessed_linear_nodes -= (
+                set(layer_subgraph.opening_nodes)
+                | set([layer_subgraph.terminating_node])
+                | set(layer_subgraph.subgraph_nodes)
             )
             layer_subgraphs.append(layer_subgraph)
         last_lin_index = terminating_indices[-1] + 1
```
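The widened set-difference also marks the linear nodes inside the layer body (`subgraph_nodes`) as processed; before this change they stayed in `unprocessed_linear_nodes`, where the `shard_all_unprocessed` fallback in sharding.py could presumably pick them up a second time. A toy illustration of the bookkeeping (strings stand in for fx nodes; all names are hypothetical):

```python
# Toy illustration of the set bookkeeping; strings stand in for torch.fx nodes.
unprocessed = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_up", "down"}

opening_nodes = ["q_proj", "k_proj", "v_proj"]
terminating_node = "o_proj"
subgraph_nodes = ["gate_up"]  # linears inside the layer body

# Before the fix: subgraph nodes were never marked processed, so the
# shard_all_unprocessed fallback could shard them a second time.
before = unprocessed - (set(opening_nodes) | {terminating_node})
assert "gate_up" in before

# After the fix: the whole layer is accounted for.
after = unprocessed - (
    set(opening_nodes) | {terminating_node} | set(subgraph_nodes)
)
assert after == {"down"}
```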
