|
5 | 5 | from contextlib import nullcontext
|
6 | 6 | import torch
|
7 | 7 | from diffusers.utils import check_min_version
|
| 8 | +import random |
| 9 | +import numpy as np |
8 | 10 |
|
9 | 11 | from pipeline import LotusGPipeline, LotusDPipeline
|
10 | 12 | from utils.seed_all import seed_all
|
@@ -86,6 +88,11 @@ def parse_args():
|
86 | 88 | default="bilinear",
|
87 | 89 | help="Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`",
|
88 | 90 | )
|
| 91 | + parser.add_argument( |
| 92 | + "--rng_state_path", |
| 93 | + default=None, |
| 94 | + help="Load the random number generator states from the given path to ensure reproducibility of the results. " |
| 95 | + ) |
89 | 96 |
|
90 | 97 | args = parser.parse_args()
|
91 | 98 |
|
@@ -131,6 +138,15 @@ def main():
|
131 | 138 | logging.warning("CUDA is not available. Running on CPU will be slow.")
|
132 | 139 | logging.info(f"Device = {device}")
|
133 | 140 |
|
| 141 | + if args.rng_state_path: |
| 142 | + torch.cuda.synchronize() |
| 143 | + states = torch.load(args.rng_state_path) |
| 144 | + random.setstate(states["random_state"]) |
| 145 | + np.random.set_state(states["numpy_random_seed"]) |
| 146 | + torch.set_rng_state(states["torch_manual_seed"]) |
| 147 | + torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"][:1]) |
| 148 | + logging.info(f"Loading the RNG states from: {args.rng_state_path}") |
| 149 | + |
134 | 150 | # -------------------- Model --------------------
|
135 | 151 | if args.mode == 'generation':
|
136 | 152 | pipeline = LotusGPipeline.from_pretrained(
|
|
0 commit comments