
Commit b8215b1

Authored by azolotenkov, sayakpaul, and github-actions[bot]
Fix incorrect seed initialization when args.seed is 0 (huggingface#10964)
* Fix seed initialization to handle args.seed = 0 correctly
* Apply style fixes

--------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 3ee899f commit b8215b1

16 files changed (+32 / -18 lines)
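Why the change matters: the old expression tested args.seed for truthiness, so a seed of 0, which is falsy in Python, silently skipped creating a seeded generator and made those validation runs non-reproducible. The fix compares against None instead, so only an unset seed disables seeding. A minimal sketch of the difference, independent of the training scripts (variable names here are illustrative):

    import torch

    seed = 0  # a valid, deliberately chosen seed

    # Old pattern: 0 is falsy, so no generator is created and seeding is skipped.
    generator_old = torch.Generator().manual_seed(seed) if seed else None
    print(generator_old)  # None

    # Fixed pattern: only an unset seed (None) disables seeding.
    generator_new = torch.Generator().manual_seed(seed) if seed is not None else None
    print(generator_new)  # <torch._C.Generator object at 0x...>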

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def log_validation(
     pipeline.set_progress_bar_config(disable=True)

     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
     autocast_ctx = nullcontext()

     with autocast_ctx:

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 8 additions & 2 deletions
@@ -1883,7 +1883,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     pipeline.set_progress_bar_config(disable=True)

     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    generator = (
+        torch.Generator(device=accelerator.device).manual_seed(args.seed)
+        if args.seed is not None
+        else None
+    )
     pipeline_args = {"prompt": args.validation_prompt}

     if torch.backends.mps.is_available():
@@ -1987,7 +1991,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     )
     # run inference
     pipeline = pipeline.to(accelerator.device)
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    generator = (
+        torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+    )
     images = [
         pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
         for _ in range(args.num_validation_images)

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 1 addition & 1 deletion
@@ -269,7 +269,7 @@ def log_validation(
     pipeline.set_progress_bar_config(disable=True)

     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
     # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
     # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
     if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:

examples/cogvideo/train_cogvideox_image_to_video_lora.py

Lines changed: 1 addition & 1 deletion
@@ -722,7 +722,7 @@ def log_validation(
     # pipe.set_progress_bar_config(disable=True)

     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

     videos = []
     for _ in range(args.num_validation_videos):

examples/cogvideo/train_cogvideox_lora.py

Lines changed: 1 addition & 1 deletion
@@ -739,7 +739,7 @@ def log_validation(
     # pipe.set_progress_bar_config(disable=True)

     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

     videos = []
     for _ in range(args.num_validation_videos):

examples/custom_diffusion/train_custom_diffusion.py

Lines changed: 3 additions & 1 deletion
@@ -1334,7 +1334,9 @@ def main(args):

     # run inference
     if args.validation_prompt and args.num_validation_images > 0:
-        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+        generator = (
+            torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+        )
         images = [
             pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
             for _ in range(args.num_validation_images)

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def log_validation(
     pipeline.set_progress_bar_config(disable=True)

     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
     # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
     autocast_ctx = nullcontext()

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ def log_validation(
     pipeline.set_progress_bar_config(disable=True)

     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

     if args.validation_images is None:
         images = []

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ def log_validation(
     pipeline.set_progress_bar_config(disable=True)

     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
     # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
     autocast_ctx = nullcontext()

examples/dreambooth/train_dreambooth_lora_lumina2.py

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ def log_validation(
     pipeline.set_progress_bar_config(disable=True)

     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
     autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

     with autocast_ctx:
