 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.sharding_plan import get_default_sharders
 from torchrec.distributed.types import (
+    DMPCollectionConfig,
+    DMPCollectionContext,
     EnumerableShardingSpec,
     ModuleSharder,
     ParameterSharding,
@@ -404,7 +406,6 @@ def _shard_modules_impl(
     module: nn.Module,
     path: str = "",
 ) -> nn.Module:
-
     # pre-sharded module
     if isinstance(module, ShardedModule):
         return module
@@ -827,53 +828,150 @@ def __init__(
         data_parallel_wrapper: Optional[DataParallelWrapper] = None,
         use_inter_host_allreduce: bool = False,
         custom_all_reduce: Optional[Callable[[List[torch.Tensor]], None]] = None,
+        submodule_configs: Optional[List[DMPCollectionConfig]] = None,
     ) -> None:
         assert (
             device.type == "cuda" or device.type == "mtia"
         ), "DMPCollection only supports CUDA or MTIA"
         self._device = device
         self._pg: dist.ProcessGroup = global_pg
-        self._plan: ShardingPlan = plan
-        self._device_mesh: DeviceMesh = None  # pyre-ignore[8]
-        self._sharding_pg: dist.ProcessGroup = None  # pyre-ignore[8]
-        self._replica_pg: dist.ProcessGroup = None  # pyre-ignore[8]
         self._global_rank: int = dist.get_rank(global_pg)
         self._custom_all_reduce = custom_all_reduce
 
-        self._device_mesh, self._sharding_pg, self._replica_pg = (
-            self._create_process_groups(
+        if sharders is None:
+            sharders = get_default_sharders()
+        self._sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = {
+            sharder.module_type: sharder for sharder in sharders
+        }
+
+        # the args provided by the user are used for the default modules
+        # default context is at index 0; TODO - find a cleaner way to distinguish it
+        self._ctxs: List[DMPCollectionContext] = [
+            DMPCollectionContext(
+                # default context has module type None
+                module=None,  # pyre-ignore[6]
+                plan=plan,
+                sharding_group_size=sharding_group_size,
+                node_group_size=node_group_size,
+                use_inter_host_allreduce=use_inter_host_allreduce,
+            )
+        ]
+
+        if submodule_configs is not None:
+            for submodule_config in submodule_configs:
+                self._ctxs.append(
+                    DMPCollectionContext(
+                        module=submodule_config.module,
+                        plan=submodule_config.plan,
+                        sharding_group_size=submodule_config.sharding_group_size,
+                        use_inter_host_allreduce=submodule_config.use_inter_host_allreduce,
+                    )
+                )
+
+        # create process groups and remap sharding plans per module context
+        for ctx in self._ctxs:
+            (
+                device_mesh,
+                sharding_pg,
+                replica_pg,
+            ) = self._create_process_groups(
                 global_rank=self._global_rank,
                 world_size=world_size,
-                local_size=sharding_group_size,
-                use_inter_host_allreduce=use_inter_host_allreduce,
+                local_size=ctx.sharding_group_size,
+                use_inter_host_allreduce=ctx.use_inter_host_allreduce,
             )
+
+            ctx.device_mesh = device_mesh
+            ctx.sharding_pg = sharding_pg
+            ctx.replica_pg = replica_pg
+
+            step = world_size // ctx.sharding_group_size
+            self._remap_sharding_plan(
+                plan=ctx.plan,
+                rank=self._global_rank,
+                step=step,
+                sharding_group_size=ctx.sharding_group_size,
+                use_inter_host_allreduce=ctx.use_inter_host_allreduce,
+            )
+
+            if ctx.module:
+                ctx.sharded_module = self._sharder_map[ctx.module].sharded_module_type  # pyre-ignore[16]
+
+        consolidated_plan = copy.deepcopy(self._ctxs[0].plan)
+        for ctx in self._ctxs[1:]:
+            for key, val in ctx.plan.plan.items():
+                consolidated_plan.plan[key] = copy.deepcopy(val)
+
+        logger.info(
+            "[TorchRec 2D Parallel] Consolidated sharding plan:\n%s", consolidated_plan
         )
 
-        self._remap_sharding_plan(
-            plan=plan,
-            rank=self._global_rank,
-            step=world_size // sharding_group_size,
-            sharding_group_size=sharding_group_size,
-            use_inter_host_allreduce=use_inter_host_allreduce,
+        default_env = ShardingEnv2D(
+            global_pg=self._pg,
+            sharding_pg=self._ctxs[0].sharding_pg,
+            device_mesh=self._ctxs[0].device_mesh,
+            node_group_size=node_group_size,
+            use_inter_host_allreduce=self._ctxs[0].use_inter_host_allreduce,
         )
-        super().__init__(
+
+        super().__init__(  # type: ignore[misc]
             module,
-            ShardingEnv2D(
-                global_pg=self._pg,
-                sharding_pg=self._sharding_pg,
-                device_mesh=self._device_mesh,
-                node_group_size=node_group_size,
-                use_inter_host_allreduce=use_inter_host_allreduce,
-            ),
+            default_env,
             device,
-            plan,
+            consolidated_plan,
             sharders,
             init_data_parallel,
             init_parameters,
             data_parallel_wrapper,
         )
-        # post DMP init, we group sharded modules for parameter sync
-        self._modules_to_sync: List[nn.Module] = self._group_sharded_modules()
+
+        # post DMP init, we group sharded modules for parameter sync, stored in the context
+        self._group_sharded_modules(self._ctxs)
+
+    def _shard_modules_impl(
+        self,
+        module: nn.Module,
+        path: str = "",
+    ) -> nn.Module:
+
+        # pre-sharded module
+        if isinstance(module, ShardedModule):
+            return module
+
+        # shardable module
+        module_sharding_plan = self._plan.get_plan_for_module(path)
+        if module_sharding_plan:
+            env = self._env
+            sharder_key = type(module)
+
+            for ctx in self._ctxs[1:]:
+                if ctx.module == sharder_key:
+                    env = ShardingEnv2D(
+                        global_pg=self._pg,
+                        sharding_pg=ctx.sharding_pg,
+                        device_mesh=ctx.device_mesh,
+                        node_group_size=ctx.sharding_group_size,
+                        use_inter_host_allreduce=ctx.use_inter_host_allreduce,
+                    )
+                    break
+
+            module = self._sharder_map[sharder_key].shard(
+                module,
+                module_sharding_plan,
+                env,
+                self.device,
+                path,
+            )
+            return module
+
+        for name, child in module.named_children():
+            child = self._shard_modules_impl(
+                child,
+                path + "." + name if path else name,
+            )
+            setattr(module, name, child)
+
+        return module
 
     def sync(self, include_optimizer_state: bool = True) -> None:
         """
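
For orientation, a usage sketch of the new submodule_configs path (not part of the diff). The keyword names mirror the parameters visible in __init__ above; the DMPCollection import path, the keyword construction of DMPCollectionConfig, and the helper name build_2d_model are assumptions, and the plans and process group are supplied by the caller.

    import torch
    import torch.distributed as dist
    import torch.nn as nn

    from torchrec.distributed.model_parallel import DMPCollection  # assumed import path
    from torchrec.distributed.types import DMPCollectionConfig, ShardingPlan
    from torchrec.modules.embedding_modules import EmbeddingBagCollection


    def build_2d_model(
        model: nn.Module,
        default_plan: ShardingPlan,
        ebc_plan: ShardingPlan,
        pg: dist.ProcessGroup,
    ) -> DMPCollection:
        # Default context: modules not covered by a submodule config use sharding
        # groups of 8 ranks. EmbeddingBagCollection gets its own context with
        # sharding groups of 16 ranks and its own plan.
        return DMPCollection(
            module=model,
            device=torch.device("cuda"),
            plan=default_plan,
            world_size=dist.get_world_size(pg),
            sharding_group_size=8,
            global_pg=pg,
            submodule_configs=[
                DMPCollectionConfig(
                    module=EmbeddingBagCollection,  # a module *type*, matched against type() during sharding
                    plan=ebc_plan,
                    sharding_group_size=16,
                    use_inter_host_allreduce=False,
                )
            ],
        )
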
@@ -888,10 +986,24 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         Args:
             include_optimizer_state (bool): Flag to include optimizer state syncing upon call
         """
-        assert self._replica_pg is not None, "replica_pg is not initialized!"
+        # we sync per context to use the right all reduce process group
+        for ctx in self._ctxs:
+            self._sync(
+                ctx.replica_pg,
+                ctx.modules_to_sync,
+                include_optimizer_state,
+            )
+
+    def _sync(
+        self,
+        replica_pg: dist.ProcessGroup,
+        modules_to_sync: List[nn.Module],
+        include_optimizer_state: bool = True,
+    ) -> None:
+        assert replica_pg is not None, "replica_pg is not initialized!"
         all_weights_by_dtype: dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
 
-        for emb_kernel in self._modules_to_sync:
+        for emb_kernel in modules_to_sync:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             for w in emb_kernel.split_embedding_weights():
                 all_weights_by_dtype[w.dtype].append(w)
@@ -900,25 +1012,31 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         if self._custom_all_reduce is None:
             opts = dist.AllreduceCoalescedOptions()
             opts.reduceOp = dist.ReduceOp.AVG
-            self._allreduce_tensors(all_weights_by_dtype, "## 2d_weight_sync ##", opts)
+            self._allreduce_tensors(
+                replica_pg, all_weights_by_dtype, "## 2d_weight_sync ##", opts
+            )
 
         if include_optimizer_state:
             optimizer_tensors_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = (
                 defaultdict(list)
             )
-            for emb_kernel in self._modules_to_sync:
+            for emb_kernel in modules_to_sync:
                 # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
                 optimizer_states = emb_kernel.get_optimizer_state()
                 for state in optimizer_states:
                     opt_tensor = state["sum"]
                     optimizer_tensors_by_dtype[opt_tensor.dtype].append(opt_tensor)
             if optimizer_tensors_by_dtype:
                 self._allreduce_tensors(
-                    optimizer_tensors_by_dtype, "## 2d_optimizer_sync ##", opts
+                    replica_pg,
+                    optimizer_tensors_by_dtype,
+                    "## 2d_optimizer_sync ##",
+                    opts,
                 )
 
     def _allreduce_tensors(
         self,
+        pg: dist.ProcessGroup,
         tensors_dict: Dict[torch.dtype, List[torch.Tensor]],
         annotation: str,
         opts: Optional[dist.AllreduceCoalescedOptions] = None,
@@ -939,7 +1057,7 @@ def _all_reduce(tensors: List[torch.Tensor]) -> None:
 
         def _all_reduce(tensors: List[torch.Tensor]) -> None:
             with record_function(annotation):
-                self._replica_pg.allreduce_coalesced(tensors, opts=opts).wait()
+                pg.allreduce_coalesced(tensors, opts=opts).wait()
 
         for tensor_list in tensors_dict.values():
             _all_reduce(tensor_list)
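
The sync path above buckets tensors by dtype and issues one coalesced, averaging all-reduce per bucket over the replica process group. A standalone sketch of that pattern (the helper name is illustrative; the collective calls are the same ones used in _allreduce_tensors):

    from collections import defaultdict
    from typing import Dict, List

    import torch
    import torch.distributed as dist


    def allreduce_avg_by_dtype(pg: dist.ProcessGroup, tensors: List[torch.Tensor]) -> None:
        # bucket tensors by dtype so each coalesced collective stays homogeneous
        buckets: Dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
        for t in tensors:
            buckets[t.dtype].append(t)

        opts = dist.AllreduceCoalescedOptions()
        opts.reduceOp = dist.ReduceOp.AVG
        for bucket in buckets.values():
            # one averaging all-reduce per dtype bucket over the given process group
            pg.allreduce_coalesced(bucket, opts=opts).wait()
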
@@ -1073,7 +1191,38 @@ def _remap_sharding_plan(
 
     def _group_sharded_modules(
         self,
+        contexts: List[DMPCollectionContext],
+    ) -> None:
+        # Post init DMP, save the embedding kernels, with respect to contexts
+        for context in contexts[1:]:
+            context.modules_to_sync = self._group_sharded_module(context.sharded_module)  # pyre-ignore[6]
+
+        # Group leftover embedding kernels, with respect to default context
+        modules_to_skip: List[nn.Module] = [c.sharded_module for c in contexts[1:]]  # pyre-ignore[9]
+        sharded_modules: List[nn.Module] = []
+
+        def _find_sharded_modules(
+            module: nn.Module,
+        ) -> None:
+            if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
+                sharded_modules.append(module)
+            if not isinstance(
+                module, tuple(modules_to_skip)  # pyre-ignore[6]
+            ) and hasattr(module, "_lookups"):
+                for lookup in module._lookups:  # pyre-ignore[29]
+                    _find_sharded_modules(lookup)
+
+            for _, child in module.named_children():
+                _find_sharded_modules(child)
+
+        _find_sharded_modules(self._dmp_wrapped_module)
+        contexts[0].modules_to_sync = sharded_modules
+
+    def _group_sharded_module(
+        self,
+        sharded_module: nn.Module,
     ) -> List[nn.Module]:
+        # Traverse module and find all sharded module kernels matching the sharded module
         # Post init DMP, save the embedding kernels
         sharded_modules: List[nn.Module] = []
 
@@ -1082,36 +1231,20 @@ def _find_sharded_modules(
         ) -> None:
             if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
                 sharded_modules.append(module)
-            if hasattr(module, "_lookups"):
-                # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is
-                # not a function.
-                for lookup in module._lookups:
+            if isinstance(module, sharded_module):  # pyre-ignore[6]
+                for lookup in module._lookups:  # pyre-ignore[29]
                     _find_sharded_modules(lookup)
-                return
+
             for _, child in module.named_children():
                 _find_sharded_modules(child)
 
         _find_sharded_modules(self._dmp_wrapped_module)
         return sharded_modules
 
-    @property
-    def sharding_pg(self) -> dist.ProcessGroup:
-        """
-        Returns the process group used for this ranks sharding.
-        """
-        return self._sharding_pg
-
-    @property
-    def replica_pg(self) -> dist.ProcessGroup:
-        """
-        Returns the process group used for this ranks replication.
-        """
-        return self._replica_pg
-
     @property
     def device_mesh(self) -> DeviceMesh:
         """
         Returns the device mesh used for 2D parallelism.
         Contains two dimensions: "replicate" and "shard".
         """
-        return self._device_mesh
+        return self._ctxs[0].device_mesh
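
The device_mesh property now exposes the default context's 2D mesh. A rough sketch of the two dimensions it describes, using the public torch.distributed.device_mesh API (a 16-rank job with sharding groups of 8 and contiguous rank blocks is assumed; the actual mesh construction lives in _create_process_groups and may differ):

    from torch.distributed.device_mesh import init_device_mesh

    # assumes torch.distributed is already initialized with world_size == 16
    world_size = 16
    sharding_group_size = 8
    replicate_dim = world_size // sharding_group_size  # number of model replicas

    mesh = init_device_mesh(
        "cuda",
        (replicate_dim, sharding_group_size),
        mesh_dim_names=("replicate", "shard"),
    )
    # mesh["shard"]: ranks that cooperate on one sharded copy of the embedding tables
    # mesh["replicate"]: ranks that hold the same shard and all-reduce it in sync()
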