|
28 | 28 |
|
29 | 29 |
|
30 | 30 | def step_fsdp_transformer_t2v( |
31 | | - pipe, |
32 | | - model_map: Dict, |
| 31 | + scheduler, |
| 32 | + model, |
33 | 33 | batch, |
34 | 34 | device, |
35 | 35 | bf16, |
@@ -76,7 +76,7 @@ def step_fsdp_transformer_t2v( |
76 | 76 | # Flow Matching Timestep Sampling |
77 | 77 | # ======================================================================== |
78 | 78 |
|
79 | | - num_train_timesteps = pipe.scheduler.config.num_train_timesteps |
| 79 | + num_train_timesteps = scheduler.config.num_train_timesteps |
80 | 80 |
|
81 | 81 | if use_sigma_noise: |
82 | 82 | use_uniform = torch.rand(1).item() < mix_uniform_ratio |
@@ -117,7 +117,6 @@ def step_fsdp_transformer_t2v( |
117 | 117 | sigma = u |
118 | 118 | sampling_method = "uniform_no_shift" |
119 | 119 |
|
120 | | - |
121 | 120 | # ======================================================================== |
122 | 121 | # Manual Flow Matching Noise Addition |
123 | 122 | # ======================================================================== |
@@ -200,10 +199,8 @@ def step_fsdp_transformer_t2v( |
200 | 199 | # Forward Pass |
201 | 200 | # ======================================================================== |
202 | 201 |
|
203 | | - fsdp_model = model_map["transformer"]["fsdp_transformer"] |
204 | | - |
205 | 202 | try: |
206 | | - model_pred = fsdp_model( |
| 203 | + model_pred = model( |
207 | 204 | hidden_states=noisy_latents, |
208 | 205 | timestep=timesteps_for_model, |
209 | 206 | encoder_hidden_states=text_embeddings, |
@@ -257,7 +254,7 @@ def step_fsdp_transformer_t2v( |
257 | 254 | logger.info(f"[STEP {global_step}] LOSS DEBUG") |
258 | 255 | logger.info("=" * 80) |
259 | 256 | logger.info("[TARGET] Flow matching: v = ε - x_0") |
260 | | - logger.info(f"[PREDICTION] Scheduler type (inference only): {type(pipe.scheduler).__name__}") |
| 257 | + logger.info(f"[PREDICTION] Scheduler type (inference only): {type(scheduler).__name__}") |
261 | 258 | logger.info("") |
262 | 259 | logger.info(f"[RANGES] Model pred: [{model_pred.min():.4f}, {model_pred.max():.4f}]") |
263 | 260 | logger.info(f"[RANGES] Target (v): [{target.min():.4f}, {target.max():.4f}]") |
|
0 commit comments