@@ -1278,6 +1278,31 @@ def remove_model(old_ckpt_name):
12781278 original_args_min_timestep = args .min_timestep
12791279 original_args_max_timestep = args .max_timestep
12801280
def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
    """Snapshot all RNG states that validation will perturb.

    Returns a tuple of (CPU torch RNG state, accelerator-device torch RNG
    state or None when the device has no separate generator, Python `random`
    module state) so training randomness can be restored after the
    validation pass reseeds the generators.
    """
    cpu_rng_state = torch.get_rng_state()
    device_type = accelerator.device.type
    if device_type == "cuda":
        gpu_rng_state = torch.cuda.get_rng_state()
    elif device_type == "xpu":
        gpu_rng_state = torch.xpu.get_rng_state()
    elif device_type == "mps":
        # Fix: was torch.cuda.get_rng_state(); MPS has its own generator and
        # there is no CUDA runtime on Apple silicon.
        gpu_rng_state = torch.mps.get_rng_state()
    else:
        gpu_rng_state = None
    python_rng_state = random.getstate()
    return (cpu_rng_state, gpu_rng_state, python_rng_state)
1293+
def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
    """Restore RNG states captured by ``get_rng_state``.

    Args:
        rng_states: (CPU torch RNG state, device torch RNG state or None,
            Python `random` module state) as returned by ``get_rng_state``.
    """
    cpu_rng_state, gpu_rng_state, python_rng_state = rng_states
    torch.set_rng_state(cpu_rng_state)
    if gpu_rng_state is not None:
        device_type = accelerator.device.type
        if device_type == "cuda":
            torch.cuda.set_rng_state(gpu_rng_state)
        elif device_type == "xpu":
            torch.xpu.set_rng_state(gpu_rng_state)
        elif device_type == "mps":
            # Fix: was torch.cuda.set_rng_state(); must restore through the
            # MPS backend to match what get_rng_state captured.
            torch.mps.set_rng_state(gpu_rng_state)
    random.setstate(python_rng_state)
1305+
12811306 for epoch in range (epoch_to_start , num_train_epochs ):
12821307 accelerator .print (f"\n epoch { epoch + 1 } /{ num_train_epochs } \n " )
12831308 current_epoch .value = epoch + 1
@@ -1391,7 +1416,7 @@ def remove_model(old_ckpt_name):
13911416 if accelerator .sync_gradients and validation_steps > 0 and should_validate_step :
13921417 optimizer_eval_fn ()
13931418 accelerator .unwrap_model (network ).eval ()
1394- rng_state = torch . get_rng_state ()
1419+ rng_states = get_rng_state ()
13951420 torch .manual_seed (args .validation_seed if args .validation_seed is not None else args .seed )
13961421
13971422 val_progress_bar = tqdm (
@@ -1453,7 +1478,7 @@ def remove_model(old_ckpt_name):
14531478 }
14541479 accelerator .log (logs , step = global_step )
14551480
1456- torch . set_rng_state (rng_state )
1481+ set_rng_state (rng_states )
14571482 args .min_timestep = original_args_min_timestep
14581483 args .max_timestep = original_args_max_timestep
14591484 optimizer_train_fn ()
@@ -1470,7 +1495,7 @@ def remove_model(old_ckpt_name):
14701495 if should_validate_epoch and len (val_dataloader ) > 0 :
14711496 optimizer_eval_fn ()
14721497 accelerator .unwrap_model (network ).eval ()
1473- rng_state = torch . get_rng_state ()
1498+ rng_states = get_rng_state ()
14741499 torch .manual_seed (args .validation_seed if args .validation_seed is not None else args .seed )
14751500
14761501 val_progress_bar = tqdm (
@@ -1536,7 +1561,7 @@ def remove_model(old_ckpt_name):
15361561 }
15371562 accelerator .log (logs , step = global_step )
15381563
1539- torch . set_rng_state (rng_state )
1564+ set_rng_state (rng_states )
15401565 args .min_timestep = original_args_min_timestep
15411566 args .max_timestep = original_args_max_timestep
15421567 optimizer_train_fn ()
0 commit comments