Skip to content

Commit b3702da

Browse files
committed
pass scheduler and model
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent a5b94b9 commit b3702da

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

dfm/src/Automodel/flow_matching/training_step_t2v.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828

2929

3030
def step_fsdp_transformer_t2v(
31-
pipe,
32-
model_map: Dict,
31+
scheduler,
32+
model,
3333
batch,
3434
device,
3535
bf16,
@@ -76,7 +76,7 @@ def step_fsdp_transformer_t2v(
7676
# Flow Matching Timestep Sampling
7777
# ========================================================================
7878

79-
num_train_timesteps = pipe.scheduler.config.num_train_timesteps
79+
num_train_timesteps = scheduler.config.num_train_timesteps
8080

8181
if use_sigma_noise:
8282
use_uniform = torch.rand(1).item() < mix_uniform_ratio
@@ -117,7 +117,6 @@ def step_fsdp_transformer_t2v(
117117
sigma = u
118118
sampling_method = "uniform_no_shift"
119119

120-
121120
# ========================================================================
122121
# Manual Flow Matching Noise Addition
123122
# ========================================================================
@@ -200,10 +199,8 @@ def step_fsdp_transformer_t2v(
200199
# Forward Pass
201200
# ========================================================================
202201

203-
fsdp_model = model_map["transformer"]["fsdp_transformer"]
204-
205202
try:
206-
model_pred = fsdp_model(
203+
model_pred = model(
207204
hidden_states=noisy_latents,
208205
timestep=timesteps_for_model,
209206
encoder_hidden_states=text_embeddings,
@@ -257,7 +254,7 @@ def step_fsdp_transformer_t2v(
257254
logger.info(f"[STEP {global_step}] LOSS DEBUG")
258255
logger.info("=" * 80)
259256
logger.info("[TARGET] Flow matching: v = ε - x_0")
260-
logger.info(f"[PREDICTION] Scheduler type (inference only): {type(pipe.scheduler).__name__}")
257+
logger.info(f"[PREDICTION] Scheduler type (inference only): {type(scheduler).__name__}")
261258
logger.info("")
262259
logger.info(f"[RANGES] Model pred: [{model_pred.min():.4f}, {model_pred.max():.4f}]")
263260
logger.info(f"[RANGES] Target (v): [{target.min():.4f}, {target.max():.4f}]")

0 commit comments

Comments
 (0)