Skip to content

Commit 42c0a9e

Browse files
committed
Merge branch 'sd3' into val-loss-improvement
2 parents 0750859 + 0778dd9 commit 42c0a9e

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

flux_train_network.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ def get_noise_pred_and_target(
377377

378378
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
379379
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
380+
380381
with torch.set_grad_enabled(is_train), accelerator.autocast():
381382
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
382383
model_pred = unet(

0 commit comments

Comments
 (0)