from pydantic import BaseModel, ConfigDict, Field
from torch.fx import GraphModule, Node

+from ...models.factory import FactorySource
from ...utils.logger import ad_logger
from ...utils.node_utils import (
    extract_param_names_from_lin_node,
@@ -254,9 +255,10 @@ def apply(self, gm: GraphModule, node: Node) -> None:
class ShardingConfig(BaseModel):
    """Configuration for sharding the model."""

-    rank: int = 0
-    world_size: int = 1
-    predefined_config: Dict[str, Any] = None
+    factory_source: FactorySource
+    rank: int
+    world_size: int
+    _predefined_config: Optional[Dict[str, Any]] = None
    simple_shard_only: bool = False
    use_sharding_from_factory: bool = False
    tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
@@ -265,21 +267,81 @@ class ShardingConfig(BaseModel):

    def __init__(
        self,
+        factory_source: FactorySource,
        rank: int,
        world_size: int,
        sharding_config: Dict[str, Any] = None,
        simple_shard_only: bool = False,
        use_sharding_from_factory: bool = False,
    ):
-        super().__init__()
-        self.rank = rank
-        self.world_size = world_size
-        self.predefined_config = sharding_config
-        self.simple_shard_only = simple_shard_only
-        self.use_sharding_from_factory = use_sharding_from_factory
+        super().__init__(
+            factory_source=factory_source,
+            rank=rank,
+            world_size=world_size,
+            _predefined_config=sharding_config,
+            simple_shard_only=simple_shard_only,
+            use_sharding_from_factory=use_sharding_from_factory,
+        )
+
+        # Pydantic does not support setting private fields directly.
+        self._predefined_config = sharding_config
+        # Validate the config after initialization
+        if self._predefined_config is not None:
+            self.validate_config()
+
+    def validate_config(self) -> bool:
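+        """Check that the factory-provided sharding config is usable.
+
+        Returns True if all checks pass; otherwise logs a warning, resets
+        ``_predefined_config`` to None, and returns False.
+        """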
+        if self.factory_source != FactorySource.HUGGINGFACE:
+            ad_logger.warning(
+                "Sharding config is currently only supported for HuggingFace. Skipping."
+            )
+            # invalidate the config
+            self._predefined_config = None
+            return False
+
+        if not isinstance(self._predefined_config, dict):
+            ad_logger.warning("Sharding config is not a dictionary. Skipping.")
+            # invalidate the config
+            self._predefined_config = None
+            return False
+
+        if "head_dim" not in self._predefined_config:
+            ad_logger.warning("Sharding config does not contain head_dim. Skipping.")
+            # invalidate the config
+            self._predefined_config = None
+            return False
+
+        if "tp_plan" not in self._predefined_config:
+            ad_logger.warning("Sharding config does not contain tp_plan. Skipping.")
+            # invalidate the config
+            self._predefined_config = None
+            return False
+        tp_plan = self._predefined_config["tp_plan"]
+
+        values = set(tp_plan.values())
+        allowed_values = {
+            "colwise",  # row split and no collective
+            "rowwise",  # column split and all-reduce
+            "gather",  # simple shard (row + all_gather)
+            # TODO: remaining values are not supported yet.
+            # They require hybrid EP+TP and/or SP support.
+            # "sequence_parallel",  # sequence parallelism
+            # "local_colwise",
+            # "local_rowwise",
+            # "local_packed_rowwise",
+            # "local",
+        }
+        if not values.issubset(allowed_values):
+            ad_logger.warning("Sharding config contains invalid values. Skipping.")
+            # invalidate the config
+            self._predefined_config = None
+            return False
+        return True
+
+    def get_predefined_config(self) -> Optional[Dict[str, Any]]:
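+        """Return the factory sharding config, or None if it is missing or failed validation."""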
+        return self._predefined_config


-def detect_tp_sharding_from_factory_config(
+def detect_sharding_from_factory_config(
    gm: GraphModule,
    sharding_config: ShardingConfig,
) -> None:
@@ -305,54 +367,30 @@ def detect_tp_sharding_from_factory_config(
    # The following constraints are based on
    # https://github.com/huggingface/transformers/blob/d8e05951b8efd4880acca9a3f291e8b65841a86d/src/transformers/models/llama4/configuration_llama4.py#L249

-    if not isinstance(sharding_config.predefined_config, dict):
-        ad_logger.warning("Sharding config is not a dictionary. Skipping.")
-        return
-
-    if "head_dim" not in sharding_config.predefined_config:
-        ad_logger.warning("Sharding config does not contain head_dim. Skipping.")
-        return
-    head_dim = sharding_config.predefined_config["head_dim"]
-
-    if "tp_plan" not in sharding_config.predefined_config:
-        ad_logger.warning("Sharding config does not contain tp_plan. Skipping.")
-        return
-    tp_plan = sharding_config.predefined_config["tp_plan"]
-
-    values = set(tp_plan.values())
-    allowed_values = {
-        "colwise",
-        "rowwise",
-        "sequence_parallel",
-        "local_colwise",
-        "local_rowwise",
-        "local_packed_rowwise",
-        "local",
-        "gather",
-    }
-    if not values.issubset(allowed_values):
-        ad_logger.warning("Sharding config contains invalid values. Skipping.")
-        return
+    factory_config = sharding_config.get_predefined_config()
+    head_dim = factory_config["head_dim"]
+    tp_plan = factory_config["tp_plan"]

    rank, world_size = sharding_config.rank, sharding_config.world_size

+    # If the node is inside the attention module, we need to set min_local_shape to the
+    # head_dim - otherwise, we would risk splitting the heads into smaller shards.
+    # TODO: is there a better way to check if we are in attention module?
+    attn_names = [
+        "attention",
+        "Attention",
+        "attn",
+        "Attn",
+        "q_proj",
+        "k_proj",
+        "v_proj",
+        "o_proj",
+    ]
+
    for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op):
        # use node's weight name to get the module name
        module_name = lin_node.args[1].target

-        # If the node is inside the attention module, we need to set min_local_shape to the
-        # head_dim - otherwise, we would risk splitting the heads into smaller shards.
-        # TODO: is there a better way to check if we are in attention module?
-        attn_names = [
-            "attention",
-            "Attention",
-            "attn",
-            "Attn",
-            "q_proj",
-            "k_proj",
-            "v_proj",
-            "o_proj",
-        ]
        if any(attn_name in module_name for attn_name in attn_names):
            min_local_shape = head_dim
        else:
@@ -424,7 +462,8 @@ def simple_shard_first_n_layers(sharding_config: ShardingConfig, n_layers: int)
    # instead of "*".
    """
    new_tp_plan = {}
-    for layer_pattern, config in sharding_config.predefined_config["tp_plan"].items():
+    factory_config = sharding_config.get_predefined_config()
+    for layer_pattern, config in factory_config["tp_plan"].items():
        if "*" in layer_pattern:
            # Create new dict with first n_layers entries first

@@ -434,7 +473,7 @@ def simple_shard_first_n_layers(sharding_config: ShardingConfig, n_layers: int)
        # Add the default config after
        new_tp_plan[layer_pattern] = config

-    sharding_config.predefined_config["tp_plan"] = new_tp_plan
+    sharding_config._predefined_config["tp_plan"] = new_tp_plan


def simple_shard_last_n_layers(sharding_config: ShardingConfig, n_layers: int) -> None:
@@ -446,8 +485,9 @@ def simple_shard_last_n_layers(sharding_config: ShardingConfig, n_layers: int) -
    # instead of "*".
    """
    new_tp_plan = {}
-    num_layers = sharding_config.predefined_config["num_hidden_layers"]
-    for layer_pattern, config in sharding_config.predefined_config["tp_plan"].items():
+    factory_config = sharding_config.get_predefined_config()
+    num_layers = factory_config["num_hidden_layers"]
+    for layer_pattern, config in factory_config["tp_plan"].items():
        if "*" in layer_pattern:
            # Create new dict with first n_layers entries first

@@ -456,18 +496,18 @@ def simple_shard_last_n_layers(sharding_config: ShardingConfig, n_layers: int) -

        # Add the default config after
        new_tp_plan[layer_pattern] = config
-    sharding_config.predefined_config["tp_plan"] = new_tp_plan
+    sharding_config._predefined_config["tp_plan"] = new_tp_plan


def simple_shard_attention_layers(sharding_config: ShardingConfig) -> None:
    """
    If any key in tp_plan contains "attention", replace it with "gather"
    """
-    for layer_pattern, config in sharding_config.predefined_config["tp_plan"].items():
+    for layer_pattern, config in sharding_config._predefined_config["tp_plan"].items():
        if any(
            attn_name in layer_pattern for attn_name in ["attention", "Attention", "attn", "Attn"]
        ):
-            sharding_config.predefined_config["tp_plan"][layer_pattern] = "gather"
+            sharding_config._predefined_config["tp_plan"][layer_pattern] = "gather"


def sharding_transform_executor(gm: GraphModule, sharding_config: ShardingConfig) -> None:
@@ -687,6 +727,26 @@ def _append_simple_shard(
    sharding_config.tp_transforms.extend(tp_shards)


+def detect_sharding(gm: GraphModule, sharding_config: ShardingConfig) -> None:
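+    """Populate ``sharding_config`` with the detected sharding transformations.
+
+    Uses the factory-provided sharding config when requested and available;
+    otherwise falls back to the autodeploy heuristics (TP, EP, BMM sharding).
+    """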
+    if (
+        sharding_config.use_sharding_from_factory
+        and sharding_config.get_predefined_config() is not None
+    ):
+        ad_logger.info("Applying sharding from config")
+        detect_sharding_from_factory_config(gm, sharding_config)
+        return
+
+    ad_logger.info("Running autodeploy sharding heuristics")
+    # run TP sharding across ranks
+    detect_column_row_shard(gm, sharding_config)
+
+    # run EP sharding across ranks
+    detect_ep_shard(gm, sharding_config)
+
+    # run BMM sharding across ranks
+    detect_dp_bmm_shard(gm, sharding_config)
+
+
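+# Example (sketch), assuming a traced GraphModule `gm` and a HuggingFace config dict
+# `hf_config`; the names below are illustrative, not part of this module:
+#
+#     sharding_config = ShardingConfig(
+#         factory_source=FactorySource.HUGGINGFACE,
+#         rank=rank,
+#         world_size=world_size,
+#         sharding_config=hf_config,
+#         use_sharding_from_factory=True,
+#     )
+#     detect_sharding(gm, sharding_config)
+#     sharding_transform_executor(gm, sharding_config)
+
+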
def detect_column_row_shard(
    gm: GraphModule,
    sharding_config: ShardingConfig,
@@ -716,11 +776,6 @@ def detect_column_row_shard(

    assert isinstance(gm, GraphModule), "Expecting GraphModule"

-    if sharding_config.use_sharding_from_factory and sharding_config.predefined_config is not None:
-        ad_logger.info("Using TP sharding from config")
-        detect_tp_sharding_from_factory_config(gm, sharding_config)
-        return
-
    ad_logger.info("Running TP sharding detection")

    # find boundary nodes of regions we want to shard