Skip to content

Commit 50f24c3

Browse files
committed
load RNG state for reproducibility
1 parent 7758dd9 commit 50f24c3

File tree

5 files changed

+16
-0
lines changed

5 files changed

+16
-0
lines changed

eval.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from contextlib import nullcontext
66
import torch
77
from diffusers.utils import check_min_version
8+
import random
9+
import numpy as np
810

911
from pipeline import LotusGPipeline, LotusDPipeline
1012
from utils.seed_all import seed_all
@@ -86,6 +88,11 @@ def parse_args():
8688
default="bilinear",
8789
help="Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`",
8890
)
91+
parser.add_argument(
92+
"--rng_state_path",
93+
default=None,
94+
help="Load the random number generator states from the given path to ensure reproducibility of the results. "
95+
)
8996

9097
args = parser.parse_args()
9198

@@ -131,6 +138,15 @@ def main():
131138
logging.warning("CUDA is not available. Running on CPU will be slow.")
132139
logging.info(f"Device = {device}")
133140

141+
if args.rng_state_path:
142+
torch.cuda.synchronize()
143+
states = torch.load(args.rng_state_path)
144+
random.setstate(states["random_state"])
145+
np.random.set_state(states["numpy_random_seed"])
146+
torch.set_rng_state(states["torch_manual_seed"])
147+
torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"][:1])
148+
logging.info(f"Loading the RNG states from: {args.rng_state_path}")
149+
134150
# -------------------- Model --------------------
135151
if args.mode == 'generation':
136152
pipeline = LotusGPipeline.from_pretrained(
14.8 KB
Binary file not shown.
14.7 KB
Binary file not shown.
14.7 KB
Binary file not shown.
14.8 KB
Binary file not shown.

0 commit comments

Comments
 (0)