@@ -44,11 +44,37 @@ def build_model_and_optimizer(
4444 device : torch .device ,
4545 dtype : torch .dtype ,
4646 cpu_offload : bool = False ,
47- fsdp_cfg : Dict [str , Any ] = {},
47+ fsdp_cfg : Optional [Dict [str , Any ]] = None ,
48+ ddp_cfg : Optional [Dict [str , Any ]] = None ,
4849 attention_backend : Optional [str ] = None ,
4950 optimizer_cfg : Optional [Dict [str , Any ]] = None ,
5051) -> tuple [NeMoWanPipeline , dict [str , Dict [str , Any ]], torch .optim .Optimizer , Any ]:
51- """Build the diffusion model, parallel scheme, and optimizer."""
52+ """Build the diffusion model, parallel scheme, and optimizer.
53+
54+ Args:
55+ model_id: Pretrained model name or path.
56+ finetune_mode: Whether to load for finetuning.
57+ learning_rate: Learning rate for optimizer.
58+ device: Target device.
59+ dtype: Model dtype.
60+ cpu_offload: Whether to enable CPU offload (FSDP only).
61+ fsdp_cfg: FSDP configuration dict. Mutually exclusive with ddp_cfg.
62+ ddp_cfg: DDP configuration dict. Mutually exclusive with fsdp_cfg.
63+ attention_backend: Optional attention backend override.
64+ optimizer_cfg: Optional optimizer configuration.
65+
66+ Returns:
67+ Tuple of (pipeline, optimizer, device_mesh or None).
68+
69+ Raises:
70+ ValueError: If both fsdp_cfg and ddp_cfg are provided.
71+ """
72+ # Validate mutually exclusive configs
73+ if fsdp_cfg is not None and ddp_cfg is not None :
74+ raise ValueError (
75+ "Cannot specify both 'fsdp' and 'ddp' configurations. "
76+ "Please provide only one distributed training strategy."
77+ )
5278
5379 logging .info ("[INFO] Building NeMoAutoDiffusionPipeline with transformer parallel scheme..." )
5480
@@ -57,26 +83,42 @@ def build_model_and_optimizer(
5783
5884 world_size = dist .get_world_size () if dist .is_initialized () else 1
5985
60- if fsdp_cfg .get ("dp_size" , None ) is None :
61- denom = max (1 , fsdp_cfg .get ("tp_size" , 1 ) * fsdp_cfg .get ("cp_size" , 1 ) * fsdp_cfg .get ("pp_size" , 1 ))
62- fsdp_cfg .dp_size = max (1 , world_size // denom )
63-
64- manager_args : Dict [str , Any ] = {
65- "dp_size" : fsdp_cfg .get ("dp_size" , None ),
66- "dp_replicate_size" : fsdp_cfg .get ("dp_replicate_size" , None ),
67- "tp_size" : fsdp_cfg .get ("tp_size" , 1 ),
68- "cp_size" : fsdp_cfg .get ("cp_size" , 1 ),
69- "pp_size" : fsdp_cfg .get ("pp_size" , 1 ),
70- "backend" : "nccl" ,
71- "world_size" : world_size ,
72- "use_hf_tp_plan" : fsdp_cfg .get ("use_hf_tp_plan" , False ),
73- "activation_checkpointing" : True ,
74- "mp_policy" : MixedPrecisionPolicy (
75- param_dtype = dtype ,
76- reduce_dtype = torch .float32 ,
77- output_dtype = dtype ,
78- ),
79- }
86+ # Build manager args based on which config is provided
87+ if ddp_cfg is not None :
88+ # DDP configuration
89+ logging .info ("[INFO] Using DDP (DistributedDataParallel) for training" )
90+ manager_args : Dict [str , Any ] = {
91+ "_manager_type" : "ddp" ,
92+ "backend" : ddp_cfg .get ("backend" , "nccl" ),
93+ "world_size" : world_size ,
94+ "activation_checkpointing" : ddp_cfg .get ("activation_checkpointing" , False ),
95+ }
96+ else :
97+ # FSDP configuration (default)
98+ fsdp_cfg = fsdp_cfg or {}
99+ logging .info ("[INFO] Using FSDP2 (Fully Sharded Data Parallel) for training" )
100+
101+ if fsdp_cfg .get ("dp_size" , None ) is None :
102+ denom = max (1 , fsdp_cfg .get ("tp_size" , 1 ) * fsdp_cfg .get ("cp_size" , 1 ) * fsdp_cfg .get ("pp_size" , 1 ))
103+ fsdp_cfg ["dp_size" ] = max (1 , world_size // denom )
104+
105+ manager_args : Dict [str , Any ] = {
106+ "_manager_type" : "fsdp2" ,
107+ "dp_size" : fsdp_cfg .get ("dp_size" , None ),
108+ "dp_replicate_size" : fsdp_cfg .get ("dp_replicate_size" , None ),
109+ "tp_size" : fsdp_cfg .get ("tp_size" , 1 ),
110+ "cp_size" : fsdp_cfg .get ("cp_size" , 1 ),
111+ "pp_size" : fsdp_cfg .get ("pp_size" , 1 ),
112+ "backend" : "nccl" ,
113+ "world_size" : world_size ,
114+ "use_hf_tp_plan" : fsdp_cfg .get ("use_hf_tp_plan" , False ),
115+ "activation_checkpointing" : fsdp_cfg .get ("activation_checkpointing" , True ),
116+ "mp_policy" : MixedPrecisionPolicy (
117+ param_dtype = dtype ,
118+ reduce_dtype = torch .float32 ,
119+ output_dtype = dtype ,
120+ ),
121+ }
80122
81123 parallel_scheme = {"transformer" : manager_args }
82124
@@ -194,10 +236,19 @@ def setup(self):
194236 logging .info (f"[INFO] Node rank: { self .node_rank } , Local rank: { self .local_rank } " )
195237 logging .info (f"[INFO] Learning rate: { self .learning_rate } " )
196238
197- fsdp_cfg = self .cfg .get ("fsdp" , {})
239+ # Get distributed training configs (mutually exclusive)
240+ fsdp_cfg = self .cfg .get ("fsdp" , None )
241+ ddp_cfg = self .cfg .get ("ddp" , None )
198242 fm_cfg = self .cfg .get ("flow_matching" , {})
199243
200- self .cpu_offload = fsdp_cfg .get ("cpu_offload" , False )
244+ # Validate mutually exclusive distributed configs
245+ if fsdp_cfg is not None and ddp_cfg is not None :
246+ raise ValueError (
247+ "Cannot specify both 'fsdp' and 'ddp' configurations in YAML. "
248+ "Please provide only one distributed training strategy."
249+ )
250+
251+ self .cpu_offload = fsdp_cfg .get ("cpu_offload" , False ) if fsdp_cfg else False
201252
202253 # Flow matching configuration
203254 self .adapter_type = fm_cfg .get ("adapter_type" , "simple" )
@@ -233,6 +284,7 @@ def setup(self):
233284 dtype = self .bf16 ,
234285 cpu_offload = self .cpu_offload ,
235286 fsdp_cfg = fsdp_cfg ,
287+ ddp_cfg = ddp_cfg ,
236288 optimizer_cfg = self .cfg .get ("optim.optimizer" , {}),
237289 attention_backend = self .attention_backend ,
238290 )
@@ -288,13 +340,19 @@ def setup(self):
288340 raise RuntimeError ("Training dataloader is empty; cannot proceed with training" )
289341
290342 # Derive DP size consistent with model parallel config
291- tp_size = fsdp_cfg .get ("tp_size" , 1 )
292- cp_size = fsdp_cfg .get ("cp_size" , 1 )
293- pp_size = fsdp_cfg .get ("pp_size" , 1 )
294- denom = max (1 , tp_size * cp_size * pp_size )
295- self .dp_size = fsdp_cfg .get ("dp_size" , None )
296- if self .dp_size is None :
297- self .dp_size = max (1 , self .world_size // denom )
343+ if ddp_cfg is not None :
344+ # DDP uses pure data parallelism across all ranks
345+ self .dp_size = self .world_size
346+ else :
347+ # FSDP may have TP/CP/PP dimensions
348+ _fsdp_cfg = fsdp_cfg or {}
349+ tp_size = _fsdp_cfg .get ("tp_size" , 1 )
350+ cp_size = _fsdp_cfg .get ("cp_size" , 1 )
351+ pp_size = _fsdp_cfg .get ("pp_size" , 1 )
352+ denom = max (1 , tp_size * cp_size * pp_size )
353+ self .dp_size = _fsdp_cfg .get ("dp_size" , None )
354+ if self .dp_size is None :
355+ self .dp_size = max (1 , self .world_size // denom )
298356
299357 # Infer local micro-batch size from dataloader if available
300358 self .local_batch_size = self .cfg .step_scheduler .local_batch_size
@@ -449,3 +507,21 @@ def run_train_validation_loop(self):
449507 wandb .finish ()
450508
451509 logging .info ("[INFO] Training complete!" )
510+
511+ def _get_dp_rank (self , include_cp : bool = False ) -> int :
512+ """Get data parallel rank, handling DDP mode where device_mesh is None."""
513+ # In DDP mode, device_mesh is None, so use torch.distributed directly
514+ device_mesh = getattr (self , "device_mesh" , None )
515+ if device_mesh is None :
516+ return dist .get_rank () if dist .is_initialized () else 0
517+ # Otherwise, use the parent implementation
518+ return super ()._get_dp_rank (include_cp = include_cp )
519+
520+ def _get_dp_group_size (self , include_cp : bool = False ) -> int :
521+ """Get data parallel world size, handling DDP mode where device_mesh is None."""
522+ # In DDP mode, device_mesh is None, so use torch.distributed directly
523+ device_mesh = getattr (self , "device_mesh" , None )
524+ if device_mesh is None :
525+ return dist .get_world_size () if dist .is_initialized () else 1
526+ # Otherwise, use the parent implementation
527+ return super ()._get_dp_group_size (include_cp = include_cp )
0 commit comments