Skip to content

Commit 45ec02b

Browse files
committed
Use the same noise for every validation run
1 parent 42c0a9e commit 45ec02b

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

flux_train_network.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -377,7 +377,6 @@ def get_noise_pred_and_target(
377377

378378
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
379379
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
380-
381380
with torch.set_grad_enabled(is_train), accelerator.autocast():
382381
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
383382
model_pred = unet(

train_network.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1391,6 +1391,8 @@ def remove_model(old_ckpt_name):
13911391
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
13921392
optimizer_eval_fn()
13931393
accelerator.unwrap_model(network).eval()
1394+
rng_state = torch.get_rng_state()
1395+
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
13941396

13951397
val_progress_bar = tqdm(
13961398
range(validation_total_steps),
@@ -1451,6 +1453,7 @@ def remove_model(old_ckpt_name):
14511453
}
14521454
accelerator.log(logs, step=global_step)
14531455

1456+
torch.set_rng_state(rng_state)
14541457
args.min_timestep = original_args_min_timestep
14551458
args.max_timestep = original_args_max_timestep
14561459
optimizer_train_fn()
@@ -1467,6 +1470,8 @@ def remove_model(old_ckpt_name):
14671470
if should_validate_epoch and len(val_dataloader) > 0:
14681471
optimizer_eval_fn()
14691472
accelerator.unwrap_model(network).eval()
1473+
rng_state = torch.get_rng_state()
1474+
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
14701475

14711476
val_progress_bar = tqdm(
14721477
range(validation_total_steps),
@@ -1531,6 +1536,7 @@ def remove_model(old_ckpt_name):
15311536
}
15321537
accelerator.log(logs, step=global_step)
15331538

1539+
torch.set_rng_state(rng_state)
15341540
args.min_timestep = original_args_min_timestep
15351541
args.max_timestep = original_args_max_timestep
15361542
optimizer_train_fn()

0 commit comments

Comments
 (0)