2222from typing import DefaultDict , Dict , List , Set , Tuple , Type
2323
2424import torch
25- from pydantic import Field
25+ from pydantic import Field , field_validator
2626from torch .fx import GraphModule , Node
2727
28+ from .....functional import AllReduceStrategy
2829from ...models .factory import ModelFactory , ShardingConfigSource
2930from ...shim .interface import CachedSequenceInterface
3031from ...utils .logger import ad_logger
4950 SplitDimension ,
5051 WeightShardingInfo ,
5152 get_all_weights_in_subgraph ,
53+ validate_allreduce_strategy ,
5254)
5355from ..interface import (
5456 BaseTransform ,
@@ -152,6 +154,18 @@ class ShardingTransformConfig(TransformConfig):
152154 sharding_dims : List [ShardingDim ] = Field (
153155 default_factory = lambda : [ShardingDim .SSM , ShardingDim .TP , ShardingDim .EP , ShardingDim .BMM ]
154156 )
157+ allreduce_strategy : AllReduceStrategy = Field (
158+ default = AllReduceStrategy .AUTO ,
159+ description = "AllReduce strategy for distributed operations. "
160+ "Options: AUTO (automatic selection), NCCL, ONESHOT, TWOSHOT, MIN_LATENCY, "
161+ "LOWPRECISION, UB, MNNVL, NCCL_SYMMETRIC" ,
162+ )
163+
@field_validator("allreduce_strategy", mode="before")
@classmethod
def _validate_allreduce_strategy(cls, value):
    """Coerce raw inputs (e.g. the string 'AUTO') into an AllReduceStrategy enum member.

    Runs in pydantic's "before" mode, so it sees the value prior to any
    type coercion and delegates the actual name->enum conversion to the
    shared ``validate_allreduce_strategy`` helper.
    """
    return validate_allreduce_strategy(value)
155169
156170
157171@TransformRegistry .register ("detect_sharding" )
@@ -199,6 +213,8 @@ def _apply(
199213 sharding_config = shared_config .sharding_config
200214 sharding_config .rank = local_rank
201215 sharding_config .world_size = world_size
216+ sharding_config .allreduce_strategy = self .config .allreduce_strategy
217+ ad_logger .info (f"Using allreduce strategy: { sharding_config .allreduce_strategy .name } " )
202218 sharding_config .predefined_config = factory .get_sharding_config () if factory else {}
203219 sharding_config .factory_source = (
204220 sharding_config .predefined_config .get ("source" , ShardingConfigSource .UNKNOWN )
@@ -573,7 +589,7 @@ def detect_sharding_from_factory_config(
573589 # we have a match. Get the config for this layer
574590 config = tp_plan [key ]
575591 if config == "colwise" :
576- sharding_config .weight_sharding_transforms . append (
592+ if sharding_config .add (
577593 WeightShardingInfo .from_node (
578594 lin_node ,
579595 split_dim = SplitDimension .COLUMN ,
@@ -582,10 +598,10 @@ def detect_sharding_from_factory_config(
582598 dist_op = None ,
583599 min_local_shape = min_local_shape ,
584600 )
585- )
586- num_row_col_shards += 1
601+ ):
602+ num_row_col_shards += 1
587603 elif config == "rowwise" :
588- sharding_config .weight_sharding_transforms . append (
604+ if sharding_config .add (
589605 WeightShardingInfo .from_node (
590606 lin_node ,
591607 split_dim = SplitDimension .ROW ,
@@ -594,10 +610,10 @@ def detect_sharding_from_factory_config(
594610 dist_op = "all_reduce" ,
595611 min_local_shape = min_local_shape ,
596612 )
597- )
598- num_row_col_shards += 1
613+ ):
614+ num_row_col_shards += 1
599615 elif config == "mamba" :
600- sharding_config .weight_sharding_transforms . append (
616+ sharding_config .add (
601617 WeightShardingInfo .from_node (
602618 lin_node ,
603619 split_dim = SplitDimension .COLUMN ,
@@ -618,7 +634,7 @@ def detect_sharding_from_factory_config(
618634 if "shared" in module_name :
619635 col_row_action = config .replace ("local_" , "" )
620636 if col_row_action == "colwise" :
621- sharding_config .weight_sharding_transforms . append (
637+ sharding_config .add (
622638 WeightShardingInfo (
623639 target_node = lin_node .name ,
624640 split_dim = SplitDimension .COLUMN ,
@@ -629,7 +645,7 @@ def detect_sharding_from_factory_config(
629645 )
630646 )
631647 elif col_row_action == "rowwise" :
632- sharding_config .weight_sharding_transforms . append (
648+ if sharding_config .add (
633649 WeightShardingInfo (
634650 target_node = lin_node .name ,
635651 split_dim = SplitDimension .ROW ,
@@ -638,8 +654,8 @@ def detect_sharding_from_factory_config(
638654 dist_op = "all_reduce" ,
639655 min_local_shape = min_local_shape ,
640656 )
641- )
642- num_row_col_shards += 1
657+ ):
658+ num_row_col_shards += 1
643659 else :
644660 ad_logger .warning (f"Unsupported sharding action { config } . Skipping." )
645661 else :
@@ -648,7 +664,7 @@ def detect_sharding_from_factory_config(
648664
649665 elif "gather" in config :
650666 # Simple shard (row + all_gather)
651- sharding_config .weight_sharding_transforms . append (
667+ if sharding_config .add (
652668 WeightShardingInfo .from_node (
653669 lin_node ,
654670 split_dim = SplitDimension .COLUMN ,
@@ -657,13 +673,13 @@ def detect_sharding_from_factory_config(
657673 dist_op = "all_gather" ,
658674 min_local_shape = 1 ,
659675 )
660- )
661- num_simple_shards += 1
676+ ):
677+ num_simple_shards += 1
662678 else :
663679 ad_logger .warning (
664680 f"Unsupported sharding action { config } . Fallback to simple shard"
665681 )
666- sharding_config .weight_sharding_transforms . append (
682+ sharding_config .add (
667683 WeightShardingInfo .from_node (
668684 lin_node ,
669685 split_dim = SplitDimension .COLUMN ,
@@ -943,7 +959,7 @@ def detect_column_row_shard(
943959 )
944960
945961 # shard single row node
946- sharding_config .weight_sharding_transforms . append (
962+ if sharding_config .add (
947963 WeightShardingInfo .from_node (
948964 nodes_to_row_shard [0 ],
949965 split_dim = SplitDimension .ROW ,
@@ -952,9 +968,8 @@ def detect_column_row_shard(
952968 dist_op = "all_reduce" ,
953969 min_local_shape = min_local_shape ,
954970 )
955- )
956-
957- num_row_col_shards += 1
971+ ):
972+ num_row_col_shards += 1
958973
959974 ad_logger .info (
960975 f"Found { num_shards } TP shards (simple: { num_simple_shards } , row-col: { num_row_col_shards } )"
@@ -1020,7 +1035,7 @@ def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Tra
10201035 start_idx = remainder + rank * base_size
10211036 end_idx = start_idx + base_size
10221037
1023- sharding_config .bmm_transforms . append (
1038+ sharding_config .add (
10241039 BMMShardingInfo (
10251040 target_node = node .name ,
10261041 rank = rank ,
@@ -1064,14 +1079,14 @@ def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Transfo
10641079 ),
10651080 ):
10661081 continue
1067- sharding_config .ep_transforms . append (
1082+ if sharding_config .add (
10681083 EPShardingInfo .from_node (
10691084 node ,
10701085 rank = rank ,
10711086 world_size = world_size ,
10721087 )
1073- )
1074- num_moe_patterns += 1
1088+ ):
1089+ num_moe_patterns += 1
10751090
10761091 ad_logger .info (f"Found { num_moe_patterns } MoE patterns" )
10771092
0 commit comments