Commit 320f92b

update param

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

1 parent b3702da commit 320f92b

File tree

1 file changed: +10 -14 lines

dfm/src/Automodel/recipes/train.py

Lines changed: 10 additions & 14 deletions
@@ -22,7 +22,7 @@
 import torch
 import torch.distributed as dist
 import wandb
-from Automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline
+from Automodel._diffusers.auto_diffusion_pipeline import NeMoWanPipeline
 from Automodel.flow_matching.training_step_t2v import (
     step_fsdp_transformer_t2v,
 )
@@ -51,10 +51,10 @@ def build_model_and_optimizer(
     dp_replicate_size: Optional[int] = None,
     use_hf_tp_plan: bool = False,
     optimizer_cfg: Optional[Dict[str, Any]] = None,
-) -> tuple[NeMoAutoDiffusionPipeline, dict[str, Dict[str, Any]], torch.optim.Optimizer, Any]:
+) -> tuple[NeMoWanPipeline, dict[str, Dict[str, Any]], torch.optim.Optimizer, Any]:
     """Build the WAN 2.1 diffusion model, parallel scheme, and optimizer."""

-    logging.info("[INFO] Building NeMoAutoDiffusionPipeline with transformer parallel scheme...")
+    logging.info("[INFO] Building NeMoWanPipeline with transformer parallel scheme...")

     if not dist.is_initialized():
         logging.info("[WARN] torch.distributed not initialized; proceeding in single-process mode")
@@ -84,7 +84,7 @@ def build_model_and_optimizer(

     parallel_scheme = {"transformer": manager_args}

-    pipe, created_managers = NeMoAutoDiffusionPipeline.from_pretrained(
+    pipe, created_managers = NeMoWanPipeline.from_pretrained(
         model_id,
         torch_dtype=bf16_dtype,
         device=device,
@@ -93,11 +93,7 @@
         components_to_load=["transformer"],
     )
     fsdp2_manager = created_managers["transformer"]
-    transformer_module = getattr(pipe, "transformer", None)
-    if transformer_module is None:
-        raise RuntimeError("transformer not found in pipeline after parallelization")
-
-    model_map: dict[str, Dict[str, Any]] = {"transformer": {"fsdp_transformer": transformer_module}}
+    transformer_module = pipe.transformer

     trainable_params = [p for p in transformer_module.parameters() if p.requires_grad]
     if not trainable_params:
@@ -121,7 +117,7 @@

     logging.info("[INFO] NeMoAutoDiffusion setup complete (pipeline + optimizer)")

-    return pipe, model_map, optimizer, fsdp2_manager.device_mesh
+    return pipe, optimizer, fsdp2_manager.device_mesh


 def build_lr_scheduler(
@@ -214,7 +210,7 @@ def setup(self):
         dp_replicate_size = fsdp_cfg.get("dp_replicate_size", None)
         use_hf_tp_plan = fsdp_cfg.get("use_hf_tp_plan", False)

-        (self.pipe, self.model_map, self.optimizer, self.device_mesh) = build_model_and_optimizer(
+        (self.pipe, self.optimizer, self.device_mesh) = build_model_and_optimizer(
             model_id=self.model_id,
             learning_rate=self.learning_rate,
             device=self.device,
@@ -229,7 +225,7 @@ def setup(self):
             optimizer_cfg=self.cfg.get("optim.optimizer", {}),
         )

-        self.model = self.model_map["transformer"]["fsdp_transformer"]
+        self.model = self.pipe.transformer
         self.peft_config = None

         batch_cfg = self.cfg.get("batch", {})
@@ -358,8 +354,8 @@ def run_train_validation_loop(self):
             for micro_batch in batch_group:
                 try:
                     loss, _ = step_fsdp_transformer_t2v(
-                        pipe=self.pipe,
-                        model_map=self.model_map,
+                        scheduler=self.pipe.scheduler,
+                        model=self.model,
                         batch=micro_batch,
                         device=self.device,
                         bf16=self.bf16,
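
After this commit, build_model_and_optimizer returns three values instead of four, the FSDP-wrapped transformer is read directly off the pipeline, and step_fsdp_transformer_t2v receives the scheduler and module explicitly instead of the pipe plus a model_map. Below is a minimal sketch of the updated calling convention, assuming the remaining keyword arguments keep their defaults; the checkpoint id, device string, learning rate, and dataloader are illustrative placeholders, not values taken from this file.

# Sketch of the post-commit call pattern, as it would look inside
# dfm/src/Automodel/recipes/train.py where build_model_and_optimizer is defined.
# Only names visible in the diff above are real; everything else is assumed.
pipe, optimizer, device_mesh = build_model_and_optimizer(
    model_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed example checkpoint
    learning_rate=1e-5,                           # assumed value
    device="cuda",
)

model = pipe.transformer  # FSDP-wrapped transformer now comes straight off the pipeline

for micro_batch in dataloader:  # dataloader stands in for the recipe's batch pipeline
    loss, _ = step_fsdp_transformer_t2v(
        scheduler=pipe.scheduler,  # scheduler passed explicitly instead of the whole pipe
        model=model,               # replaces model_map["transformer"]["fsdp_transformer"]
        batch=micro_batch,
        device="cuda",
        bf16=True,
    )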
