Skip to content

Commit a9a000f

Browse files
[bugfix] fix bz >1 for training (#477)
1 parent 66b8b85 commit a9a000f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

fastvideo/v1/training/wan_training_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def train_one_step(
147147
dtype=torch.bfloat16)
148148
with set_forward_context(current_timestep=timesteps,
149149
attn_metadata=None):
150-
model_pred = transformer(**input_kwargs)[0]
150+
model_pred = transformer(**input_kwargs)
151151

152152
if precondition_outputs:
153153
model_pred = noisy_model_input - model_pred * sigmas

0 commit comments

Comments
 (0)