import torch
import torch.distributed as dist
import wandb
-from Automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline
+from Automodel._diffusers.auto_diffusion_pipeline import NeMoWanPipeline
from Automodel.flow_matching.training_step_t2v import (
    step_fsdp_transformer_t2v,
)
@@ -51,10 +51,10 @@ def build_model_and_optimizer(
    dp_replicate_size: Optional[int] = None,
    use_hf_tp_plan: bool = False,
    optimizer_cfg: Optional[Dict[str, Any]] = None,
-) -> tuple[NeMoAutoDiffusionPipeline, dict[str, Dict[str, Any]], torch.optim.Optimizer, Any]:
+) -> tuple[NeMoWanPipeline, torch.optim.Optimizer, Any]:
5555 """Build the WAN 2.1 diffusion model, parallel scheme, and optimizer."""
5656
57- logging .info ("[INFO] Building NeMoAutoDiffusionPipeline with transformer parallel scheme..." )
57+ logging .info ("[INFO] Building NeMoWanPipeline with transformer parallel scheme..." )
5858
5959 if not dist .is_initialized ():
6060 logging .info ("[WARN] torch.distributed not initialized; proceeding in single-process mode" )
@@ -84,7 +84,7 @@ def build_model_and_optimizer(

    parallel_scheme = {"transformer": manager_args}

-    pipe, created_managers = NeMoAutoDiffusionPipeline.from_pretrained(
+    pipe, created_managers = NeMoWanPipeline.from_pretrained(
        model_id,
        torch_dtype=bf16_dtype,
        device=device,
@@ -93,11 +93,7 @@ def build_model_and_optimizer(
        components_to_load=["transformer"],
    )
    fsdp2_manager = created_managers["transformer"]
-    transformer_module = getattr(pipe, "transformer", None)
-    if transformer_module is None:
-        raise RuntimeError("transformer not found in pipeline after parallelization")
-
-    model_map: dict[str, Dict[str, Any]] = {"transformer": {"fsdp_transformer": transformer_module}}
+    transformer_module = pipe.transformer

    trainable_params = [p for p in transformer_module.parameters() if p.requires_grad]
    if not trainable_params:
@@ -121,7 +117,7 @@ def build_model_and_optimizer(

    logging.info("[INFO] NeMoAutoDiffusion setup complete (pipeline + optimizer)")

-    return pipe, model_map, optimizer, fsdp2_manager.device_mesh
+    return pipe, optimizer, fsdp2_manager.device_mesh


def build_lr_scheduler(
@@ -214,7 +210,7 @@ def setup(self):
        dp_replicate_size = fsdp_cfg.get("dp_replicate_size", None)
        use_hf_tp_plan = fsdp_cfg.get("use_hf_tp_plan", False)

-        (self.pipe, self.model_map, self.optimizer, self.device_mesh) = build_model_and_optimizer(
+        (self.pipe, self.optimizer, self.device_mesh) = build_model_and_optimizer(
            model_id=self.model_id,
            learning_rate=self.learning_rate,
            device=self.device,
@@ -229,7 +225,7 @@ def setup(self):
            optimizer_cfg=self.cfg.get("optim.optimizer", {}),
        )

-        self.model = self.model_map["transformer"]["fsdp_transformer"]
+        self.model = self.pipe.transformer
        self.peft_config = None

        batch_cfg = self.cfg.get("batch", {})
@@ -358,8 +354,8 @@ def run_train_validation_loop(self):
            for micro_batch in batch_group:
                try:
                    loss, _ = step_fsdp_transformer_t2v(
-                        pipe=self.pipe,
-                        model_map=self.model_map,
+                        scheduler=self.pipe.scheduler,
+                        model=self.model,
                        batch=micro_batch,
                        device=self.device,
                        bf16=self.bf16,