Skip to content

Commit a111eb3

Browse files
Fixed logging
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
1 parent 45db051 commit a111eb3

File tree

1 file changed

+55
-36
lines changed
  • tensorrt_llm/_torch/auto_deploy/transform/library

1 file changed

+55
-36
lines changed

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 55 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,8 @@ def _process_simple_shard(
160160
nodes_linear: Union[Dict[Node, List[Node]], List[Node]],
161161
rank: int,
162162
world_size: int,
163-
sharding_config: ShardingTransformContainer,
163+
transform_container: ShardingTransformContainer,
164+
layer_type: LayerType = LayerType.MLP,
164165
) -> int:
165166
# for every linear node:
166167
# --> row_split (dim 0 of weight) + all_gather (dim -1 of output)
@@ -172,14 +173,15 @@ def _process_simple_shard(
172173
num_simple_shards = 0
173174
for n in nodes_linear:
174175
num_simple_shards += int(
175-
sharding_config.add(
176+
transform_container.add(
176177
WeightShardingInfo.from_node(
177178
n,
178179
split_dim=SplitDimension.COLUMN,
179180
rank=rank,
180181
world_size=world_size,
181182
dist_op="all_gather",
182183
min_local_shape=1,
184+
layer_type=layer_type,
183185
)
184186
)
185187
)
@@ -268,7 +270,7 @@ def _apply(
268270
def _process_ssm_sharding(
269271
gm: GraphModule,
270272
entry_node: Node,
271-
sharding_config: ShardingTransformContainer,
273+
transform_container: ShardingTransformContainer,
272274
rank: int,
273275
world_size: int,
274276
min_local_shape: int = 1,
@@ -317,7 +319,7 @@ def _process_ssm_sharding(
317319
# ##############################################################
318320
# ####### shard the entry_node (the first linear layer) ########
319321
# ##############################################################
320-
if not sharding_config.add(
322+
if not transform_container.add(
321323
WeightShardingInfo.from_node(
322324
entry_node,
323325
split_dim=SplitDimension.COLUMN,
@@ -339,15 +341,15 @@ def _process_ssm_sharding(
339341
split_args_0[1] = [s // world_size for s in split_args_0[1]]
340342
split_args_1 = list(split_nodes[1].args)
341343
split_args_1[1] = [s // world_size for s in split_args_1[1]]
342-
sharding_config.add(
344+
transform_container.add(
343345
ParameterUpdateInfo(
344346
rank=rank,
345347
world_size=world_size,
346348
target_node=split_nodes[0].name,
347349
args=tuple(split_args_0),
348350
)
349351
)
350-
sharding_config.add(
352+
transform_container.add(
351353
ParameterUpdateInfo(
352354
rank=rank,
353355
world_size=world_size,
@@ -368,7 +370,7 @@ def _process_ssm_sharding(
368370
# This one is also sharded, so we need to update this parameter
369371
conv_args = list(conv1d_node.args)
370372
conv_args[-1] = conv1d_node.args[-1] // world_size
371-
sharding_config.add(
373+
transform_container.add(
372374
ParameterUpdateInfo(
373375
rank=rank, world_size=world_size, target_node=conv1d_node.name, args=tuple(conv_args)
374376
)
@@ -400,7 +402,7 @@ def _process_ssm_sharding(
400402
break
401403

402404
# Shard the weight tensor (also updates the parameter in the module)
403-
sharding_config.add(
405+
transform_container.add(
404406
WeightShardingInfo.from_node(
405407
list(weight_node.users)[0],
406408
split_dim=SplitDimension.COLUMN,
@@ -429,7 +431,7 @@ def _process_ssm_sharding(
429431
args = list(view_node.args)
430432
view_shape[2] = view_shape[2] // world_size
431433
args[1] = tuple(view_shape)
432-
sharding_config.add(
434+
transform_container.add(
433435
ParameterUpdateInfo(
434436
rank=rank, world_size=world_size, target_node=view_node.name, args=tuple(args)
435437
)
@@ -439,7 +441,7 @@ def _process_ssm_sharding(
439441
##############################################################
440442
############## shard the out_proj node #######################
441443
##############################################################
442-
sharding_config.add(
444+
transform_container.add(
443445
WeightShardingInfo.from_node(
444446
out_proj_node,
445447
split_dim=SplitDimension.ROW,
@@ -589,7 +591,7 @@ def detect_sharding_from_config(
589591
TODO: currently, it applies only to TP sharding.
590592
Args:
591593
gm: Graph module to apply transformations to
592-
sharding_config: Predefined sharding configuration
594+
transform_container: containing predefined sharding configuration
593595
"""
594596
# check if config is valid.
595597
# 1. it is a Dict[str, str]
@@ -635,6 +637,7 @@ def detect_sharding_from_config(
635637
num_shards = 0
636638
num_simple_shards = 0
637639
num_row_col_shards = 0
640+
num_attention_shards = 0
638641
num_ssm_shards = 0
639642

640643
for lin_node in filtered_nodes(gm.graph.nodes, is_any_lin_op):
@@ -657,7 +660,6 @@ def detect_sharding_from_config(
657660
pattern_string = pattern_string.replace("*", "@")
658661
pattern_regex = re.escape(pattern_string).replace("@", ".*")
659662
if re.match(pattern_regex, module_name):
660-
num_shards += 1
661663
# we have a match. Get the config for this layer
662664
config = tp_plan[key]
663665
if config == "colwise":
@@ -681,11 +683,17 @@ def detect_sharding_from_config(
681683
layer_type=layer_type,
682684
)
683685
):
686+
if layer_type == LayerType.ATTENTION:
687+
num_attention_shards += 1
684688
num_row_col_shards += 1
685689
elif config == "mamba":
686-
num_ssm_shards += int(
690+
if (
687691
_process_ssm_sharding(gm, lin_node, transform_container, rank, world_size)
688-
)
692+
> 0
693+
):
694+
num_ssm_shards += 1
695+
num_row_col_shards += 1
696+
689697
elif "sequence" in config:
690698
# TODO: Sequence parallelism is not supported yet.
691699
ad_logger.warning("Sequence parallelism is not supported yet. Skipping.")
@@ -747,9 +755,10 @@ def detect_sharding_from_config(
747755
# after successful match, break the loop
748756
break
749757

758+
num_shards = num_simple_shards + num_row_col_shards
750759
ad_logger.info(
751-
f"Applied {num_shards} TP shards from config (simple: {num_simple_shards}, "
752-
f"row-col pattern: {num_row_col_shards}, ssm: {num_ssm_shards})"
760+
f"Applied {num_shards} TP shards from config. Simple: {num_simple_shards}, "
761+
f"row-col: {num_row_col_shards} (including: ssm: {num_ssm_shards}, attention: {num_attention_shards})"
753762
)
754763

755764
num_matches = len(transform_container.weight_sharding_transforms)
@@ -799,7 +808,7 @@ def detect_ssm_shard(
799808

800809
def detect_column_row_shard(
801810
gm: GraphModule,
802-
sharding_config: ShardingTransformContainer,
811+
transfrom_container: ShardingTransformContainer,
803812
) -> TransformInfo:
804813
"""A transformation to apply sharding to the model following tensor parallelism.
805814
@@ -818,7 +827,7 @@ def detect_column_row_shard(
818827
splitting, e.g., the individual heads into smaller shards.
819828
"""
820829
ad_logger.debug("Before sharding graph: " + str(gm))
821-
rank, world_size = sharding_config.rank, sharding_config.world_size
830+
rank, world_size = transfrom_container.rank, transfrom_container.world_size
822831

823832
assert isinstance(gm, GraphModule), "Expecting GraphModule"
824833
ad_logger.info("Running TP sharding detection")
@@ -838,37 +847,44 @@ def detect_column_row_shard(
838847
nodes_linear = opening + [closing]
839848
num_shards += 1
840849

841-
if sharding_config.simple_shard_only:
842-
ad_logger.debug(f"Forcing Simple Shard on nodes: {nodes_linear}")
843-
num_simple_shards += _process_simple_shard(
844-
nodes_linear, rank, world_size, sharding_config
845-
)
846-
continue
847-
848850
ssm_nodes = list(filtered_nodes(layer_subgraph, is_any_ssm_op))
849851
attention_nodes = list(filtered_nodes(layer_subgraph, is_any_attention_op))
850852
min_local_shape = 1
851-
layer_type = LayerType.MLP
853+
layer_type = (
854+
LayerType.MAMBA
855+
if len(ssm_nodes) > 0
856+
else LayerType.ATTENTION
857+
if len(attention_nodes) > 0
858+
else LayerType.MLP
859+
)
860+
861+
if transfrom_container.simple_shard_only:
862+
ad_logger.debug(
863+
f"Forcing Simple Shard on nodes: {nodes_linear} with layer type: {layer_type}"
864+
)
865+
num_simple_shards += _process_simple_shard(
866+
nodes_linear, rank, world_size, transfrom_container, layer_type=layer_type
867+
)
868+
continue
852869

853870
if len(ssm_nodes) > 0:
854871
# Mamba layers need special handling due to the fused weights for in_proj and conv1d
855872
assert len(ssm_nodes) == 1, "Expected exactly one SSM node in layer subgraph"
856873
assert len(opening) == 1, "Expected exactly one opening node in Mamba layer"
857874
ad_logger.debug(f"Found SSM nodes in layer subgraph: {ssm_nodes}")
858875
num_ssm_shards += _process_ssm_sharding(
859-
gm, opening[0], sharding_config, rank, world_size
876+
gm, opening[0], transfrom_container, rank, world_size
860877
)
861878
continue
862879

863880
if len(attention_nodes) > 0:
864-
layer_type = LayerType.ATTENTION
865881
ad_logger.debug(f"Found attention nodes in layer subgraph: {attention_nodes}")
866882
if len(attention_nodes) > 1:
867883
# Column-row shard boundary region detection is probably wrong - there should be
868884
# only one attention operation. Fall back to simple shard.
869885
ad_logger.debug(f"More than one attention node: {attention_nodes}")
870886
num_simple_shards += _process_simple_shard(
871-
nodes_linear, rank, world_size, sharding_config
887+
nodes_linear, rank, world_size, transfrom_container, layer_type=layer_type
872888
)
873889
continue
874890
# Extract head dimension. We cannot shard below the head_dim size.
@@ -890,26 +906,27 @@ def detect_column_row_shard(
890906
f"Falling back to simple shard."
891907
)
892908
num_simple_shards += _process_simple_shard(
893-
nodes_linear, rank, world_size, sharding_config
909+
nodes_linear,
910+
rank,
911+
world_size,
912+
transfrom_container,
913+
layer_type=layer_type,
894914
)
895915
# TODO: handle the case where num_kv_heads is not divisible by world_size
896916
continue
897-
num_attention_shards += 1
898-
else:
899-
layer_type = LayerType.MLP
900917

901918
# column-row sharding
902919
_process_column_sharding(
903920
linear_nodes=opening,
904921
subgraph_nodes=layer_subgraph,
905-
transform_container=sharding_config,
922+
transform_container=transfrom_container,
906923
rank=rank,
907924
world_size=world_size,
908925
min_local_shape=min_local_shape,
909926
)
910927

911928
# shard single row node
912-
if sharding_config.add(
929+
if transfrom_container.add(
913930
WeightShardingInfo.from_node(
914931
closing,
915932
split_dim=SplitDimension.ROW,
@@ -921,10 +938,12 @@ def detect_column_row_shard(
921938
)
922939
):
923940
num_column_row_shards += 1
941+
if layer_type == LayerType.ATTENTION:
942+
num_attention_shards += 1
924943

925944
# simple shard remaining linear nodes
926945
num_simple_shards += _process_simple_shard(
927-
unprocessed_linear_nodes, rank, world_size, sharding_config
946+
unprocessed_linear_nodes, rank, world_size, transfrom_container
928947
)
929948
num_column_row_shards += num_ssm_shards
930949
ad_logger.info(

0 commit comments

Comments (0)