Skip to content

Commit ff8d0c3

Browse files
committed
feat: Add DDP support in the Automodel path
Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
1 parent 9eaace1 commit ff8d0c3

File tree

6 files changed

+222
-49
lines changed

6 files changed

+222
-49
lines changed

dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
import copy
1616
import logging
1717
import os
18-
from typing import Any, Dict, Iterable, Optional, Tuple
18+
from typing import Any, Dict, Iterable, Optional, Tuple, Union
1919

2020
import torch
2121
import torch.nn as nn
2222
from diffusers import DiffusionPipeline, WanPipeline
2323
from nemo_automodel.components.distributed import parallelizer
24+
from nemo_automodel.components.distributed.ddp import DDPManager
2425
from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
2526
from nemo_automodel.shared.utils import dtype_from_str
2627

@@ -29,6 +30,9 @@
2930

3031
logger = logging.getLogger(__name__)
3132

33+
# Type alias for parallel managers
34+
ParallelManager = Union[FSDP2Manager, DDPManager]
35+
3236

3337
def _init_parallelizer():
3438
parallelizer.PARALLELIZATION_STRATEGIES["WanTransformer3DModel"] = WanParallelizationStrategy()
@@ -94,17 +98,52 @@ def _ensure_params_trainable(module: nn.Module, module_name: Optional[str] = Non
9498
return num_trainable_parameters
9599

96100

101+
def _create_parallel_manager(manager_args: Dict[str, Any]) -> ParallelManager:
102+
"""
103+
Factory function to create the appropriate parallel manager based on config.
104+
105+
The manager type is determined by the '_manager_type' key in manager_args:
106+
- 'ddp': Creates a DDPManager for standard Distributed Data Parallel
107+
- 'fsdp2' (default): Creates an FSDP2Manager for Fully Sharded Data Parallel
108+
109+
Args:
110+
manager_args: Dictionary of arguments for the manager. Must include '_manager_type'
111+
key to specify which manager to create. The '_manager_type' key is
112+
removed before passing args to the manager constructor.
113+
114+
Returns:
115+
Either an FSDP2Manager or DDPManager instance.
116+
117+
Raises:
118+
ValueError: If an unknown manager type is specified.
119+
"""
120+
# Make a copy to avoid modifying the original dict
121+
args = manager_args.copy()
122+
manager_type = args.pop("_manager_type", "fsdp2").lower()
123+
124+
if manager_type == "ddp":
125+
logger.info("[Parallel] Creating DDPManager with args: %s", args)
126+
return DDPManager(**args)
127+
elif manager_type == "fsdp2":
128+
logger.info("[Parallel] Creating FSDP2Manager with args: %s", args)
129+
return FSDP2Manager(**args)
130+
else:
131+
raise ValueError(f"Unknown manager type: '{manager_type}'. Expected 'ddp' or 'fsdp2'.")
132+
133+
97134
class NeMoAutoDiffusionPipeline(DiffusionPipeline):
98135
"""
99-
Drop-in Diffusers pipeline that adds optional FSDP2/TP parallelization during from_pretrained.
136+
Drop-in Diffusers pipeline that adds optional FSDP2/DDP parallelization during from_pretrained.
100137
101138
Features:
102-
- Accepts a per-component mapping from component name to FSDP2Manager init args
139+
- Accepts a per-component mapping from component name to parallel manager init args
103140
- Moves all nn.Module components to the chosen device/dtype
104141
- Parallelizes only components present in the mapping by constructing a manager per component
142+
- Supports both FSDP2Manager and DDPManager via '_manager_type' key in config
105143
106144
parallel_scheme:
107-
- Dict[str, Dict[str, Any]]: component name -> kwargs for FSDP2Manager(...)
145+
- Dict[str, Dict[str, Any]]: component name -> kwargs for parallel manager
146+
- Each component's kwargs should include '_manager_type': 'fsdp2' or 'ddp' (defaults to 'fsdp2')
108147
"""
109148

110149
@classmethod
@@ -119,7 +158,7 @@ def from_pretrained(
119158
load_for_training: bool = False,
120159
components_to_load: Optional[Iterable[str]] = None,
121160
**kwargs,
122-
) -> tuple[DiffusionPipeline, Dict[str, FSDP2Manager]]:
161+
) -> tuple[DiffusionPipeline, Dict[str, ParallelManager]]:
123162
pipe: DiffusionPipeline = DiffusionPipeline.from_pretrained(
124163
pretrained_model_name_or_path,
125164
*model_args,
@@ -143,16 +182,16 @@ def from_pretrained(
143182
logger.info("[INFO] Ensuring params trainable: %s", name)
144183
_ensure_params_trainable(module, module_name=name)
145184

146-
# Use per-component FSDP2Manager init-args to parallelize components
147-
created_managers: Dict[str, FSDP2Manager] = {}
185+
# Use per-component manager init-args to parallelize components
186+
created_managers: Dict[str, ParallelManager] = {}
148187
if parallel_scheme is not None:
149188
assert torch.distributed.is_initialized(), "Expect distributed environment to be initialized"
150189
_init_parallelizer()
151190
for comp_name, comp_module in _iter_pipeline_modules(pipe):
152191
manager_args = parallel_scheme.get(comp_name)
153192
if manager_args is None:
154193
continue
155-
manager = FSDP2Manager(**manager_args)
194+
manager = _create_parallel_manager(manager_args)
156195
created_managers[comp_name] = manager
157196
parallel_module = manager.parallelize(comp_module)
158197
setattr(pipe, comp_name, parallel_module)
@@ -177,7 +216,7 @@ def from_config(
177216
device: Optional[torch.device] = None,
178217
move_to_device: bool = True,
179218
components_to_load: Optional[Iterable[str]] = None,
180-
):
219+
) -> tuple[WanPipeline, Dict[str, ParallelManager]]:
181220
# Load just the config
182221
from diffusers import WanTransformer3DModel
183222

@@ -211,16 +250,16 @@ def from_config(
211250
logger.info("[INFO] Moving module: %s to device/dtype", name)
212251
_move_module_to_device(module, dev, torch_dtype)
213252

214-
# Use per-component FSDP2Manager init-args to parallelize components
215-
created_managers: Dict[str, FSDP2Manager] = {}
253+
# Use per-component manager init-args to parallelize components
254+
created_managers: Dict[str, ParallelManager] = {}
216255
if parallel_scheme is not None:
217256
assert torch.distributed.is_initialized(), "Expect distributed environment to be initialized"
218257
_init_parallelizer()
219258
for comp_name, comp_module in _iter_pipeline_modules(pipe):
220259
manager_args = parallel_scheme.get(comp_name)
221260
if manager_args is None:
222261
continue
223-
manager = FSDP2Manager(**manager_args)
262+
manager = _create_parallel_manager(manager_args)
224263
created_managers[comp_name] = manager
225264
parallel_module = manager.parallelize(comp_module)
226265
setattr(pipe, comp_name, parallel_module)

dfm/src/automodel/flow_matching/flow_matching_pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,11 +356,11 @@ def step(
356356
# ====================================================================
357357
# Logging
358358
# ====================================================================
359-
if detailed_log or debug_mode:
359+
if debug_mode and detailed_log:
360360
self._log_detailed(
361361
global_step, sampling_method, batch_size, sigma, timesteps, video_latents, noise, noisy_latents
362362
)
363-
elif summary_log:
363+
elif debug_mode and summary_log:
364364
logger.info(
365365
f"[STEP {global_step}] σ=[{sigma.min():.3f},{sigma.max():.3f}] | "
366366
f"t=[{timesteps.min():.1f},{timesteps.max():.1f}] | "
@@ -406,9 +406,9 @@ def step(
406406
raise ValueError(f"Loss exploded: {weighted_loss.item()}")
407407

408408
# Logging
409-
if detailed_log or debug_mode:
409+
if debug_mode and detailed_log:
410410
self._log_loss_detailed(global_step, model_pred, target, loss_weight, unweighted_loss, weighted_loss)
411-
elif summary_log:
411+
elif debug_mode and summary_log:
412412
logger.info(
413413
f"[STEP {global_step}] Loss: {weighted_loss.item():.6f} | "
414414
f"w=[{loss_weight.min():.2f},{loss_weight.max():.2f}]"

dfm/src/automodel/recipes/train.py

Lines changed: 107 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,37 @@ def build_model_and_optimizer(
4444
device: torch.device,
4545
dtype: torch.dtype,
4646
cpu_offload: bool = False,
47-
fsdp_cfg: Dict[str, Any] = {},
47+
fsdp_cfg: Optional[Dict[str, Any]] = None,
48+
ddp_cfg: Optional[Dict[str, Any]] = None,
4849
attention_backend: Optional[str] = None,
4950
optimizer_cfg: Optional[Dict[str, Any]] = None,
5051
) -> tuple[NeMoWanPipeline, dict[str, Dict[str, Any]], torch.optim.Optimizer, Any]:
51-
"""Build the diffusion model, parallel scheme, and optimizer."""
52+
"""Build the diffusion model, parallel scheme, and optimizer.
53+
54+
Args:
55+
model_id: Pretrained model name or path.
56+
finetune_mode: Whether to load for finetuning.
57+
learning_rate: Learning rate for optimizer.
58+
device: Target device.
59+
dtype: Model dtype.
60+
cpu_offload: Whether to enable CPU offload (FSDP only).
61+
fsdp_cfg: FSDP configuration dict. Mutually exclusive with ddp_cfg.
62+
ddp_cfg: DDP configuration dict. Mutually exclusive with fsdp_cfg.
63+
attention_backend: Optional attention backend override.
64+
optimizer_cfg: Optional optimizer configuration.
65+
66+
Returns:
67+
Tuple of (pipeline, parallel_scheme, optimizer, device_mesh or None).
68+
69+
Raises:
70+
ValueError: If both fsdp_cfg and ddp_cfg are provided.
71+
"""
72+
# Validate mutually exclusive configs
73+
if fsdp_cfg is not None and ddp_cfg is not None:
74+
raise ValueError(
75+
"Cannot specify both 'fsdp' and 'ddp' configurations. "
76+
"Please provide only one distributed training strategy."
77+
)
5278

5379
logging.info("[INFO] Building NeMoAutoDiffusionPipeline with transformer parallel scheme...")
5480

@@ -57,26 +83,42 @@ def build_model_and_optimizer(
5783

5884
world_size = dist.get_world_size() if dist.is_initialized() else 1
5985

60-
if fsdp_cfg.get("dp_size", None) is None:
61-
denom = max(1, fsdp_cfg.get("tp_size", 1) * fsdp_cfg.get("cp_size", 1) * fsdp_cfg.get("pp_size", 1))
62-
fsdp_cfg.dp_size = max(1, world_size // denom)
63-
64-
manager_args: Dict[str, Any] = {
65-
"dp_size": fsdp_cfg.get("dp_size", None),
66-
"dp_replicate_size": fsdp_cfg.get("dp_replicate_size", None),
67-
"tp_size": fsdp_cfg.get("tp_size", 1),
68-
"cp_size": fsdp_cfg.get("cp_size", 1),
69-
"pp_size": fsdp_cfg.get("pp_size", 1),
70-
"backend": "nccl",
71-
"world_size": world_size,
72-
"use_hf_tp_plan": fsdp_cfg.get("use_hf_tp_plan", False),
73-
"activation_checkpointing": True,
74-
"mp_policy": MixedPrecisionPolicy(
75-
param_dtype=dtype,
76-
reduce_dtype=torch.float32,
77-
output_dtype=dtype,
78-
),
79-
}
86+
# Build manager args based on which config is provided
87+
if ddp_cfg is not None:
88+
# DDP configuration
89+
logging.info("[INFO] Using DDP (DistributedDataParallel) for training")
90+
manager_args: Dict[str, Any] = {
91+
"_manager_type": "ddp",
92+
"backend": ddp_cfg.get("backend", "nccl"),
93+
"world_size": world_size,
94+
"activation_checkpointing": ddp_cfg.get("activation_checkpointing", False),
95+
}
96+
else:
97+
# FSDP configuration (default)
98+
fsdp_cfg = fsdp_cfg or {}
99+
logging.info("[INFO] Using FSDP2 (Fully Sharded Data Parallel) for training")
100+
101+
if fsdp_cfg.get("dp_size", None) is None:
102+
denom = max(1, fsdp_cfg.get("tp_size", 1) * fsdp_cfg.get("cp_size", 1) * fsdp_cfg.get("pp_size", 1))
103+
fsdp_cfg["dp_size"] = max(1, world_size // denom)
104+
105+
manager_args: Dict[str, Any] = {
106+
"_manager_type": "fsdp2",
107+
"dp_size": fsdp_cfg.get("dp_size", None),
108+
"dp_replicate_size": fsdp_cfg.get("dp_replicate_size", None),
109+
"tp_size": fsdp_cfg.get("tp_size", 1),
110+
"cp_size": fsdp_cfg.get("cp_size", 1),
111+
"pp_size": fsdp_cfg.get("pp_size", 1),
112+
"backend": "nccl",
113+
"world_size": world_size,
114+
"use_hf_tp_plan": fsdp_cfg.get("use_hf_tp_plan", False),
115+
"activation_checkpointing": fsdp_cfg.get("activation_checkpointing", True),
116+
"mp_policy": MixedPrecisionPolicy(
117+
param_dtype=dtype,
118+
reduce_dtype=torch.float32,
119+
output_dtype=dtype,
120+
),
121+
}
80122

81123
parallel_scheme = {"transformer": manager_args}
82124

@@ -194,10 +236,19 @@ def setup(self):
194236
logging.info(f"[INFO] Node rank: {self.node_rank}, Local rank: {self.local_rank}")
195237
logging.info(f"[INFO] Learning rate: {self.learning_rate}")
196238

197-
fsdp_cfg = self.cfg.get("fsdp", {})
239+
# Get distributed training configs (mutually exclusive)
240+
fsdp_cfg = self.cfg.get("fsdp", None)
241+
ddp_cfg = self.cfg.get("ddp", None)
198242
fm_cfg = self.cfg.get("flow_matching", {})
199243

200-
self.cpu_offload = fsdp_cfg.get("cpu_offload", False)
244+
# Validate mutually exclusive distributed configs
245+
if fsdp_cfg is not None and ddp_cfg is not None:
246+
raise ValueError(
247+
"Cannot specify both 'fsdp' and 'ddp' configurations in YAML. "
248+
"Please provide only one distributed training strategy."
249+
)
250+
251+
self.cpu_offload = fsdp_cfg.get("cpu_offload", False) if fsdp_cfg else False
201252

202253
# Flow matching configuration
203254
self.adapter_type = fm_cfg.get("adapter_type", "simple")
@@ -233,6 +284,7 @@ def setup(self):
233284
dtype=self.bf16,
234285
cpu_offload=self.cpu_offload,
235286
fsdp_cfg=fsdp_cfg,
287+
ddp_cfg=ddp_cfg,
236288
optimizer_cfg=self.cfg.get("optim.optimizer", {}),
237289
attention_backend=self.attention_backend,
238290
)
@@ -288,13 +340,19 @@ def setup(self):
288340
raise RuntimeError("Training dataloader is empty; cannot proceed with training")
289341

290342
# Derive DP size consistent with model parallel config
291-
tp_size = fsdp_cfg.get("tp_size", 1)
292-
cp_size = fsdp_cfg.get("cp_size", 1)
293-
pp_size = fsdp_cfg.get("pp_size", 1)
294-
denom = max(1, tp_size * cp_size * pp_size)
295-
self.dp_size = fsdp_cfg.get("dp_size", None)
296-
if self.dp_size is None:
297-
self.dp_size = max(1, self.world_size // denom)
343+
if ddp_cfg is not None:
344+
# DDP uses pure data parallelism across all ranks
345+
self.dp_size = self.world_size
346+
else:
347+
# FSDP may have TP/CP/PP dimensions
348+
_fsdp_cfg = fsdp_cfg or {}
349+
tp_size = _fsdp_cfg.get("tp_size", 1)
350+
cp_size = _fsdp_cfg.get("cp_size", 1)
351+
pp_size = _fsdp_cfg.get("pp_size", 1)
352+
denom = max(1, tp_size * cp_size * pp_size)
353+
self.dp_size = _fsdp_cfg.get("dp_size", None)
354+
if self.dp_size is None:
355+
self.dp_size = max(1, self.world_size // denom)
298356

299357
# Infer local micro-batch size from dataloader if available
300358
self.local_batch_size = self.cfg.step_scheduler.local_batch_size
@@ -449,3 +507,21 @@ def run_train_validation_loop(self):
449507
wandb.finish()
450508

451509
logging.info("[INFO] Training complete!")
510+
511+
def _get_dp_rank(self, include_cp: bool = False) -> int:
512+
"""Get data parallel rank, handling DDP mode where device_mesh is None."""
513+
# In DDP mode, device_mesh is None, so use torch.distributed directly
514+
device_mesh = getattr(self, "device_mesh", None)
515+
if device_mesh is None:
516+
return dist.get_rank() if dist.is_initialized() else 0
517+
# Otherwise, use the parent implementation
518+
return super()._get_dp_rank(include_cp=include_cp)
519+
520+
def _get_dp_group_size(self, include_cp: bool = False) -> int:
521+
"""Get data parallel world size, handling DDP mode where device_mesh is None."""
522+
# In DDP mode, device_mesh is None, so use torch.distributed directly
523+
device_mesh = getattr(self, "device_mesh", None)
524+
if device_mesh is None:
525+
return dist.get_world_size() if dist.is_initialized() else 1
526+
# Otherwise, use the parent implementation
527+
return super()._get_dp_group_size(include_cp=include_cp)

examples/automodel/finetune/hunyuan_t2v_flow.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ optim:
2020

2121
# FSDP (Fully Sharded Data Parallel) configuration
2222
fsdp:
23-
enable_fsdp: true
2423
dp_size: 8 # Auto-calculate based on world_size and other parallel dimensions
2524
dp_replicate_size: 1
2625
tp_size: 1 # Tensor parallelism size

examples/automodel/pretrain/wan2_1_t2v_flow.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ step_scheduler:
2222

2323
data:
2424
dataloader:
25-
_target_: dfm.src.automodel.datasets.build_wan21_dataloader
25+
_target_: dfm.src.automodel.datasets.build_dataloader
2626
meta_folder: /lustre/fsw/portfolios/coreai/users/linnanw/hdvilla_sample/pika/wan21_codes/1.3B_meta/
2727
num_workers: 2
2828
device: cpu

0 commit comments

Comments
 (0)