@@ -31,8 +31,6 @@
 from diffusers.utils import is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 
-# torch._dynamo.config.force_parameter_static_shapes = False
-
 if is_wandb_available():
     pass
 
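Note on the first hunk: the deleted comment referenced TorchDynamo's `force_parameter_static_shapes` flag. A minimal sketch of what flipping it would look like, assuming a PyTorch 2.x build where `torch._dynamo.config` exposes this flag (it defaults to `True`):

```python
# Sketch, not part of the patch: opting parameters into dynamic shapes.
import torch._dynamo

# By default Dynamo specializes compiled graphs on the concrete shapes of
# module parameters; setting this to False treats them as dynamic, trading
# per-shape recompiles for more general (and possibly slower) graphs.
torch._dynamo.config.force_parameter_static_shapes = False
```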
@@ -635,7 +633,6 @@ def main(args):
     text_encoder_2 = text_encoder_2.to(device, dtype=weight_dtype)
     vae = vae.to(device, dtype=weight_dtype)
     unet = unet.to(device, dtype=weight_dtype)
-    #unet = torch.compile(unet, backend='openxla', dynamic=True)
     optimizer = setup_optimizer(unet, args)
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
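The second hunk removes the commented-out `torch.compile` call rather than leaving dead code behind. For reference, a sketch of how the UNet would be compiled if the experiment were revived; it assumes `torch_xla` is installed (which is what registers the `openxla` backend) and reuses `unet`, `device`, and `weight_dtype` from the surrounding script:

```python
# Sketch, not part of the patch: compiling the UNet through XLA.
import torch

unet = unet.to(device, dtype=weight_dtype)
# backend="openxla" routes compilation through torch_xla; dynamic=True asks
# Dynamo to trace with symbolic shapes so that a changing batch size or
# resolution does not trigger a fresh compile for every new shape.
unet = torch.compile(unet, backend="openxla", dynamic=True)
```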
@@ -723,14 +720,6 @@ def collate_fn(examples):
     crop_top_lefts = torch.stack([torch.tensor(example["crop_top_lefts"]) for example in examples])
     prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
     pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
-    # print("model_input.shape: ", model_input.shape)
-    # print("model_input.dtype: ", model_input.dtype)
-    # print("prompt_embeds.shape: ", prompt_embeds.shape)
-    # print("prompt_embeds.dtype: ", prompt_embeds.dtype)
-    # print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape)
-    # print("pooled_prompt_embeds.dtype: ", pooled_prompt_embeds.dtype)
-    # print("original_sizes.shape: ", original_sizes.shape)
-    # print("crop_top_lefts.shape: ", crop_top_lefts.shape)
     return {
         "model_input": model_input,
         "prompt_embeds": prompt_embeds,
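The last hunk deletes the per-batch debug prints from `collate_fn`. If that visibility is still wanted, a quieter alternative is to log the shapes once; the helper below is hypothetical (not in the script) and assumes every value in the returned batch dict is a tensor:

```python
# Sketch, not part of the patch: one-time shape/dtype logging for a batch.
import logging

logger = logging.getLogger(__name__)
_shapes_logged = False

def log_batch_shapes_once(batch):
    """Log shape and dtype of each tensor in the batch, first call only."""
    global _shapes_logged
    if _shapes_logged:
        return
    for name, tensor in batch.items():
        logger.info("%s: shape=%s dtype=%s", name, tuple(tensor.shape), tensor.dtype)
    _shapes_logged = True
```

Calling `log_batch_shapes_once(batch)` on the first batch from the dataloader surfaces the same information the prints did, without flooding stdout on every step.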