Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = nullcontext()

with autocast_ctx:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1883,7 +1883,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.seed is not None
else None
)
pipeline_args = {"prompt": args.validation_prompt}

if torch.backends.mps.is_available():
Expand Down Expand Up @@ -1987,7 +1991,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
)
# run inference
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
)
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
Expand Down
2 changes: 1 addition & 1 deletion examples/cogvideo/train_cogvideox_image_to_video_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def log_validation(
# pipe.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

videos = []
for _ in range(args.num_validation_videos):
Expand Down
2 changes: 1 addition & 1 deletion examples/cogvideo/train_cogvideox_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def log_validation(
# pipe.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

videos = []
for _ in range(args.num_validation_videos):
Expand Down
4 changes: 3 additions & 1 deletion examples/custom_diffusion/train_custom_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,7 +1334,9 @@ def main(args):

# run inference
if args.validation_prompt and args.num_validation_images > 0:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
)
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
for _ in range(args.num_validation_images)
Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()

Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

if args.validation_images is None:
images = []
Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()

Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora_lumina2.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

with autocast_ctx:
Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()

Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()

Expand Down
2 changes: 1 addition & 1 deletion examples/text_to_image/train_text_to_image_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
pipeline_args = {"prompt": args.validation_prompt}
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
Expand Down
10 changes: 8 additions & 2 deletions examples/text_to_image/train_text_to_image_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.seed is not None
else None
)
pipeline_args = {"prompt": args.validation_prompt}

with autocast_ctx:
Expand Down Expand Up @@ -1305,7 +1309,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
)

with autocast_ctx:
images = [
Expand Down