 from pydantic import BaseModel, ConfigDict, Field
 from torch.fx import GraphModule, Node

-from ...models.factory import FactorySource
+from ...models.factory import ModelFactory, ShardingConfigSource
 from ...utils.logger import ad_logger
 from ...utils.node_utils import (
     extract_param_names_from_lin_node,
@@ -255,7 +255,7 @@ def apply(self, gm: GraphModule, node: Node) -> None:
 class ShardingConfig(BaseModel):
     """Configuration for sharding the model."""

-    factory_source: FactorySource
+    factory_source: ShardingConfigSource
     rank: int
     world_size: int
     _predefined_config: Optional[Dict[str, Any]] = None
@@ -267,7 +267,7 @@ class ShardingConfig(BaseModel):

     def __init__(
         self,
-        factory_source: FactorySource,
+        factory_source: ShardingConfigSource,
         rank: int,
         world_size: int,
         sharding_config: Dict[str, Any] = None,
@@ -290,30 +290,30 @@ def __init__(
         self.validate_config()

     def validate_config(self) -> bool:
-        if self.factory_source != FactorySource.HUGGINGFACE:
+        if self.factory_source != ShardingConfigSource.HUGGINGFACE:
             ad_logger.warning(
295295 "Sharding config is is currently only " + "supported for HuggingFace. Skipping."
             )
             # invalidate the config
-            self._predefined_config = None
+            self._predefined_config = {}
             return False

         if not isinstance(self._predefined_config, dict):
             ad_logger.warning("Sharding config is not a dictionary. Skipping.")
             # invalidate the config
-            self._predefined_config = None
+            self._predefined_config = {}
             return False

         if "head_dim" not in self._predefined_config:
             ad_logger.warning("Sharding config does not contain head_dim. Skipping.")
             # invalidate the config
-            self._predefined_config = None
+            self._predefined_config = {}
             return False

         if "tp_plan" not in self._predefined_config:
             ad_logger.warning("Sharding config does not contain tp_plan. Skipping.")
             # invalidate the config
-            self._predefined_config = None
+            self._predefined_config = {}
             return False
         tp_plan = self._predefined_config["tp_plan"]

@@ -333,7 +333,7 @@ def validate_config(self) -> bool:
         if not values.issubset(allowed_values):
             ad_logger.warning("Sharding config contains invalid values. Skipping.")
             # invalidate the config
-            self._predefined_config = None
+            self._predefined_config = {}
             return False
         return True

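For reference, a minimal sketch of a predefined config that would pass the validation above. The key names (`head_dim`, `tp_plan`) come straight from the checks in this hunk; the concrete `tp_plan` entries are assumptions, since the real `allowed_values` set is defined outside the lines shown here:

```python
# Illustrative only, not part of this diff. Key names mirror the checks in
# validate_config(); the "colwise"/"rowwise" values are assumed placeholders,
# since the actual allowed_values set lives outside this hunk.
example_sharding_config = {
    "head_dim": 128,  # example value
    "tp_plan": {
        "model.layers.*.self_attn.q_proj": "colwise",     # assumed allowed value
        "model.layers.*.self_attn.o_proj": "rowwise",     # assumed allowed value
    },
}
```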
@@ -727,10 +727,26 @@ def _append_simple_shard(
     sharding_config.tp_transforms.extend(tp_shards)


-def detect_sharding(gm: GraphModule, sharding_config: ShardingConfig) -> None:
+def detect_sharding(
+    gm: GraphModule,
+    factory: ModelFactory,
+    local_rank: int,
+    world_size: int,
+    simple_shard_only: bool,
+    use_sharding_from_factory: bool,
+) -> ShardingConfig:
+    sharding_config = ShardingConfig(
+        factory.get_sharding_config_source(),
+        local_rank,
+        world_size,
+        factory.get_sharding_config(),
+        simple_shard_only,
+        use_sharding_from_factory,
+    )
+
     if (
         sharding_config.use_sharding_from_factory
-        and sharding_config.get_predefined_config() is not None
+        and len(sharding_config.get_predefined_config()) > 0
     ):
         ad_logger.info("Applying sharding from config")
         detect_sharding_from_factory_config(gm, sharding_config)
@@ -746,6 +762,8 @@ def detect_sharding(gm: GraphModule, sharding_config: ShardingConfig) -> None:
     # run BMM sharding across ranks
     detect_dp_bmm_shard(gm, sharding_config)

+    return sharding_config
+

 def detect_column_row_shard(
     gm: GraphModule,
@@ -771,7 +789,7 @@ def detect_column_row_shard(

     rank, world_size = sharding_config.rank, sharding_config.world_size
     if world_size < 2:
-        ad_logger.info("Skipping sharding for single device")
+        ad_logger.info("Skipping TP sharding for single device")
         return

     assert isinstance(gm, GraphModule), "Expecting GraphModule"
@@ -937,7 +955,7 @@ def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> None:
     ad_logger.debug("Before sharding graph: " + str(gm))
     rank, world_size = sharding_config.rank, sharding_config.world_size
     if world_size < 2:
-        ad_logger.info("Skipping sharding for single device")
+        ad_logger.info("Skipping DP BMM sharding for single device")
         return

     assert isinstance(gm, GraphModule), "Expecting GraphModule"
@@ -1008,7 +1026,7 @@ def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> None:

     rank, world_size = sharding_config.rank, sharding_config.world_size
     if world_size < 2:
-        ad_logger.info("Skipping sharding for single device")
+        ad_logger.info("Skipping EP sharding for single device")
         return

     assert isinstance(gm, GraphModule), "Expecting GraphModule"
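For context, a rough sketch of how a caller is expected to use the refactored entry point. The surrounding transform pipeline is not part of this diff, so `gm`, `factory`, and the distributed parameters below are placeholders:

```python
# Hypothetical call site, not part of this diff. detect_sharding() now builds the
# ShardingConfig itself from the ModelFactory and returns it, instead of mutating
# a config object passed in by the caller.
sharding_config = detect_sharding(
    gm,                               # torch.fx.GraphModule to analyze
    factory,                          # ModelFactory providing get_sharding_config()/get_sharding_config_source()
    local_rank,                       # rank of this process
    world_size,                       # total number of devices
    simple_shard_only=False,
    use_sharding_from_factory=True,   # prefer the factory's tp_plan when it validates
)
# The returned config accumulates the detected transforms (e.g. tp_transforms)
# for the later sharding-application pass.
```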