@@ -160,7 +160,8 @@ def _process_simple_shard(
160160 nodes_linear : Union [Dict [Node , List [Node ]], List [Node ]],
161161 rank : int ,
162162 world_size : int ,
163- sharding_config : ShardingTransformContainer ,
163+ transform_container : ShardingTransformContainer ,
164+ layer_type : LayerType = LayerType .MLP ,
164165) -> int :
165166 # for every linear node:
166167 # --> row_split (dim 0 of weight) + all_gather (dim -1 of output)
@@ -172,14 +173,15 @@ def _process_simple_shard(
172173 num_simple_shards = 0
173174 for n in nodes_linear :
174175 num_simple_shards += int (
175- sharding_config .add (
176+ transform_container .add (
176177 WeightShardingInfo .from_node (
177178 n ,
178179 split_dim = SplitDimension .COLUMN ,
179180 rank = rank ,
180181 world_size = world_size ,
181182 dist_op = "all_gather" ,
182183 min_local_shape = 1 ,
184+ layer_type = layer_type ,
183185 )
184186 )
185187 )
@@ -268,7 +270,7 @@ def _apply(
268270def _process_ssm_sharding (
269271 gm : GraphModule ,
270272 entry_node : Node ,
271- sharding_config : ShardingTransformContainer ,
273+ transform_container : ShardingTransformContainer ,
272274 rank : int ,
273275 world_size : int ,
274276 min_local_shape : int = 1 ,
@@ -317,7 +319,7 @@ def _process_ssm_sharding(
317319 # ##############################################################
318320 # ####### shard the entry_node (the first linear layer) ########
319321 # ##############################################################
320- if not sharding_config .add (
322+ if not transform_container .add (
321323 WeightShardingInfo .from_node (
322324 entry_node ,
323325 split_dim = SplitDimension .COLUMN ,
@@ -339,15 +341,15 @@ def _process_ssm_sharding(
339341 split_args_0 [1 ] = [s // world_size for s in split_args_0 [1 ]]
340342 split_args_1 = list (split_nodes [1 ].args )
341343 split_args_1 [1 ] = [s // world_size for s in split_args_1 [1 ]]
342- sharding_config .add (
344+ transform_container .add (
343345 ParameterUpdateInfo (
344346 rank = rank ,
345347 world_size = world_size ,
346348 target_node = split_nodes [0 ].name ,
347349 args = tuple (split_args_0 ),
348350 )
349351 )
350- sharding_config .add (
352+ transform_container .add (
351353 ParameterUpdateInfo (
352354 rank = rank ,
353355 world_size = world_size ,
@@ -368,7 +370,7 @@ def _process_ssm_sharding(
368370 # This one is also sharded, so we need to update this parameter
369371 conv_args = list (conv1d_node .args )
370372 conv_args [- 1 ] = conv1d_node .args [- 1 ] // world_size
371- sharding_config .add (
373+ transform_container .add (
372374 ParameterUpdateInfo (
373375 rank = rank , world_size = world_size , target_node = conv1d_node .name , args = tuple (conv_args )
374376 )
@@ -400,7 +402,7 @@ def _process_ssm_sharding(
400402 break
401403
402404 # Shard the weight tensor (also updates the parameter in the module)
403- sharding_config .add (
405+ transform_container .add (
404406 WeightShardingInfo .from_node (
405407 list (weight_node .users )[0 ],
406408 split_dim = SplitDimension .COLUMN ,
@@ -429,7 +431,7 @@ def _process_ssm_sharding(
429431 args = list (view_node .args )
430432 view_shape [2 ] = view_shape [2 ] // world_size
431433 args [1 ] = tuple (view_shape )
432- sharding_config .add (
434+ transform_container .add (
433435 ParameterUpdateInfo (
434436 rank = rank , world_size = world_size , target_node = view_node .name , args = tuple (args )
435437 )
@@ -439,7 +441,7 @@ def _process_ssm_sharding(
439441 ##############################################################
440442 ############## shard the out_proj node #######################
441443 ##############################################################
442- sharding_config .add (
444+ transform_container .add (
443445 WeightShardingInfo .from_node (
444446 out_proj_node ,
445447 split_dim = SplitDimension .ROW ,
@@ -589,7 +591,7 @@ def detect_sharding_from_config(
589591 TODO: currently, it applies only to TP sharding.
590592 Args:
591593 gm: Graph module to apply transformations to
592- sharding_config: Predefined sharding configuration
594+ transform_container : container holding the predefined sharding configuration
593595 """
594596 # check if config is valid.
595597 # 1. it is a Dict[str, str]
@@ -635,6 +637,7 @@ def detect_sharding_from_config(
635637 num_shards = 0
636638 num_simple_shards = 0
637639 num_row_col_shards = 0
640+ num_attention_shards = 0
638641 num_ssm_shards = 0
639642
640643 for lin_node in filtered_nodes (gm .graph .nodes , is_any_lin_op ):
@@ -657,7 +660,6 @@ def detect_sharding_from_config(
657660 pattern_string = pattern_string .replace ("*" , "@" )
658661 pattern_regex = re .escape (pattern_string ).replace ("@" , ".*" )
659662 if re .match (pattern_regex , module_name ):
660- num_shards += 1
661663 # we have a match. Get the config for this layer
662664 config = tp_plan [key ]
663665 if config == "colwise" :
@@ -681,11 +683,17 @@ def detect_sharding_from_config(
681683 layer_type = layer_type ,
682684 )
683685 ):
686+ if layer_type == LayerType .ATTENTION :
687+ num_attention_shards += 1
684688 num_row_col_shards += 1
685689 elif config == "mamba" :
686- num_ssm_shards += int (
690+ if (
687691 _process_ssm_sharding (gm , lin_node , transform_container , rank , world_size )
688- )
692+ > 0
693+ ):
694+ num_ssm_shards += 1
695+ num_row_col_shards += 1
696+
689697 elif "sequence" in config :
690698 # TODO: Sequence parallelism is not supported yet.
691699 ad_logger .warning ("Sequence parallelism is not supported yet. Skipping." )
@@ -747,9 +755,10 @@ def detect_sharding_from_config(
747755 # after successful match, break the loop
748756 break
749757
758+ num_shards = num_simple_shards + num_row_col_shards
750759 ad_logger .info (
751- f"Applied { num_shards } TP shards from config (simple : { num_simple_shards } , "
752- f"row-col pattern : { num_row_col_shards } , ssm: { num_ssm_shards } )"
760+ f"Applied { num_shards } TP shards from config. Simple : { num_simple_shards } , "
761+ f"row-col: { num_row_col_shards } (including: ssm: { num_ssm_shards } , attention: { num_attention_shards } )"
753762 )
754763
755764 num_matches = len (transform_container .weight_sharding_transforms )
@@ -799,7 +808,7 @@ def detect_ssm_shard(
799808
800809def detect_column_row_shard (
801810 gm : GraphModule ,
802- sharding_config : ShardingTransformContainer ,
811+ transform_container : ShardingTransformContainer ,
803812) -> TransformInfo :
804813 """A transformation to apply sharding to the model following tensor parallelism.
805814
@@ -818,7 +827,7 @@ def detect_column_row_shard(
818827 splitting, e.g., the individual heads into smaller shards.
819828 """
820829 ad_logger .debug ("Before sharding graph: " + str (gm ))
821- rank , world_size = sharding_config .rank , sharding_config .world_size
830+ rank , world_size = transform_container .rank , transform_container .world_size
822831
823832 assert isinstance (gm , GraphModule ), "Expecting GraphModule"
824833 ad_logger .info ("Running TP sharding detection" )
@@ -838,37 +847,44 @@ def detect_column_row_shard(
838847 nodes_linear = opening + [closing ]
839848 num_shards += 1
840849
841- if sharding_config .simple_shard_only :
842- ad_logger .debug (f"Forcing Simple Shard on nodes: { nodes_linear } " )
843- num_simple_shards += _process_simple_shard (
844- nodes_linear , rank , world_size , sharding_config
845- )
846- continue
847-
848850 ssm_nodes = list (filtered_nodes (layer_subgraph , is_any_ssm_op ))
849851 attention_nodes = list (filtered_nodes (layer_subgraph , is_any_attention_op ))
850852 min_local_shape = 1
851- layer_type = LayerType .MLP
853+ layer_type = (
854+ LayerType .MAMBA
855+ if len (ssm_nodes ) > 0
856+ else LayerType .ATTENTION
857+ if len (attention_nodes ) > 0
858+ else LayerType .MLP
859+ )
860+
861+ if transform_container .simple_shard_only :
862+ ad_logger .debug (
863+ f"Forcing Simple Shard on nodes: { nodes_linear } with layer type: { layer_type } "
864+ )
865+ num_simple_shards += _process_simple_shard (
866+ nodes_linear , rank , world_size , transform_container , layer_type = layer_type
867+ )
868+ continue
852869
853870 if len (ssm_nodes ) > 0 :
854871 # Mamba layers need special handling due to the fused weights for in_proj and conv1d
855872 assert len (ssm_nodes ) == 1 , "Expected exactly one SSM node in layer subgraph"
856873 assert len (opening ) == 1 , "Expected exactly one opening node in Mamba layer"
857874 ad_logger .debug (f"Found SSM nodes in layer subgraph: { ssm_nodes } " )
858875 num_ssm_shards += _process_ssm_sharding (
859- gm , opening [0 ], sharding_config , rank , world_size
876+ gm , opening [0 ], transform_container , rank , world_size
860877 )
861878 continue
862879
863880 if len (attention_nodes ) > 0 :
864- layer_type = LayerType .ATTENTION
865881 ad_logger .debug (f"Found attention nodes in layer subgraph: { attention_nodes } " )
866882 if len (attention_nodes ) > 1 :
867883 # Column-row shard boundary region detection is probably wrong - there should be
868884 # only one attention operation. Fall back to simple shard.
869885 ad_logger .debug (f"More than one attention node: { attention_nodes } " )
870886 num_simple_shards += _process_simple_shard (
871- nodes_linear , rank , world_size , sharding_config
887+ nodes_linear , rank , world_size , transform_container , layer_type = layer_type
872888 )
873889 continue
874890 # Extract head dimension. We cannot shard below the head_dim size.
@@ -890,26 +906,27 @@ def detect_column_row_shard(
890906 f"Falling back to simple shard."
891907 )
892908 num_simple_shards += _process_simple_shard (
893- nodes_linear , rank , world_size , sharding_config
909+ nodes_linear ,
910+ rank ,
911+ world_size ,
912+ transform_container ,
913+ layer_type = layer_type ,
894914 )
895915 # TODO: handle the case where num_kv_heads is not divisible by world_size
896916 continue
897- num_attention_shards += 1
898- else :
899- layer_type = LayerType .MLP
900917
901918 # column-row sharding
902919 _process_column_sharding (
903920 linear_nodes = opening ,
904921 subgraph_nodes = layer_subgraph ,
905- transform_container = sharding_config ,
922+ transform_container = transform_container ,
906923 rank = rank ,
907924 world_size = world_size ,
908925 min_local_shape = min_local_shape ,
909926 )
910927
911928 # shard single row node
912- if sharding_config .add (
929+ if transform_container .add (
913930 WeightShardingInfo .from_node (
914931 closing ,
915932 split_dim = SplitDimension .ROW ,
@@ -921,10 +938,12 @@ def detect_column_row_shard(
921938 )
922939 ):
923940 num_column_row_shards += 1
941+ if layer_type == LayerType .ATTENTION :
942+ num_attention_shards += 1
924943
925944 # simple shard remaining linear nodes
926945 num_simple_shards += _process_simple_shard (
927- unprocessed_linear_nodes , rank , world_size , sharding_config
946+ unprocessed_linear_nodes , rank , world_size , transform_container
928947 )
929948 num_column_row_shards += num_ssm_shards
930949 ad_logger .info (
0 commit comments