Skip to content

Commit c5b803c

Browse files
committed
rng state management: Implement functions to get and set RNG states for consistent validation
1 parent 45ec02b commit c5b803c

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

train_network.py

Lines changed: 29 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1278,6 +1278,31 @@ def remove_model(old_ckpt_name):
12781278
original_args_min_timestep = args.min_timestep
12791279
original_args_max_timestep = args.max_timestep
12801280

1281+
def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
    """Snapshot the RNG states that validation may disturb.

    Captures the torch CPU generator, the generator of the backend named by
    ``accelerator.device.type`` and Python's ``random`` module state.

    Returns:
        A ``(cpu_rng_state, gpu_rng_state, python_rng_state)`` tuple.
        ``gpu_rng_state`` is ``None`` when the device type is not one of
        ``"cuda"``, ``"xpu"`` or ``"mps"``.
    """
    cpu_rng_state = torch.get_rng_state()

    device_type = accelerator.device.type
    if device_type == "cuda":
        gpu_rng_state = torch.cuda.get_rng_state()
    elif device_type == "xpu":
        gpu_rng_state = torch.xpu.get_rng_state()
    elif device_type == "mps":
        # MPS has its own generator; querying the CUDA one here would read
        # the wrong backend (and fails on builds without CUDA support).
        gpu_rng_state = torch.mps.get_rng_state()
    else:
        gpu_rng_state = None

    python_rng_state = random.getstate()
    return (cpu_rng_state, gpu_rng_state, python_rng_state)
1293+
1294+
def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
    """Restore RNG states from a ``(cpu, gpu, python)`` state tuple.

    Args:
        rng_states: ``(cpu_rng_state, gpu_rng_state, python_rng_state)``.
            ``gpu_rng_state`` may be ``None``, in which case no device
            generator is touched.
    """
    cpu_rng_state, gpu_rng_state, python_rng_state = rng_states
    torch.set_rng_state(cpu_rng_state)

    if gpu_rng_state is not None:
        device_type = accelerator.device.type
        if device_type == "cuda":
            torch.cuda.set_rng_state(gpu_rng_state)
        elif device_type == "xpu":
            torch.xpu.set_rng_state(gpu_rng_state)
        elif device_type == "mps":
            # Restore into the MPS generator; the CUDA setter targets the
            # wrong backend (and fails on builds without CUDA support).
            torch.mps.set_rng_state(gpu_rng_state)

    random.setstate(python_rng_state)
1305+
12811306
for epoch in range(epoch_to_start, num_train_epochs):
12821307
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
12831308
current_epoch.value = epoch + 1
@@ -1391,7 +1416,7 @@ def remove_model(old_ckpt_name):
13911416
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
13921417
optimizer_eval_fn()
13931418
accelerator.unwrap_model(network).eval()
1394-
rng_state = torch.get_rng_state()
1419+
rng_states = get_rng_state()
13951420
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
13961421

13971422
val_progress_bar = tqdm(
@@ -1453,7 +1478,7 @@ def remove_model(old_ckpt_name):
14531478
}
14541479
accelerator.log(logs, step=global_step)
14551480

1456-
torch.set_rng_state(rng_state)
1481+
set_rng_state(rng_states)
14571482
args.min_timestep = original_args_min_timestep
14581483
args.max_timestep = original_args_max_timestep
14591484
optimizer_train_fn()
@@ -1470,7 +1495,7 @@ def remove_model(old_ckpt_name):
14701495
if should_validate_epoch and len(val_dataloader) > 0:
14711496
optimizer_eval_fn()
14721497
accelerator.unwrap_model(network).eval()
1473-
rng_state = torch.get_rng_state()
1498+
rng_states = get_rng_state()
14741499
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
14751500

14761501
val_progress_bar = tqdm(
@@ -1536,7 +1561,7 @@ def remove_model(old_ckpt_name):
15361561
}
15371562
accelerator.log(logs, step=global_step)
15381563

1539-
torch.set_rng_state(rng_state)
1564+
set_rng_state(rng_states)
15401565
args.min_timestep = original_args_min_timestep
15411566
args.max_timestep = original_args_max_timestep
15421567
optimizer_train_fn()

0 commit comments

Comments
 (0)