|
44 | 44 | from transformers import AutoTokenizer, PretrainedConfig
45 | 45 |
46 | 46 | import diffusers
47 |    | -from diffusers import (
48 |    | -    AutoencoderKL,
49 |    | -    DDPMScheduler,
50 |    | -    StableDiffusionXLPipeline,
51 |    | -    UNet2DConditionModel,
52 |    | -)
   | 47 | +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
53 | 48 | from diffusers.optimization import get_scheduler
54 | 49 | from diffusers.training_utils import EMAModel, compute_snr
55 | 50 | from diffusers.utils import check_min_version, is_wandb_available
56 | 51 | from diffusers.utils.import_utils import is_xformers_available
   | 52 | +from diffusers.utils.torch_utils import is_compiled_module
57 | 53 |
58 | 54 |
59 | 55 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
@@ -508,11 +504,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
|
508 | 504 |             prompt_embeds = text_encoder(
509 | 505 |                 text_input_ids.to(text_encoder.device),
510 | 506 |                 output_hidden_states=True,
    | 507 | +                return_dict=False,
511 | 508 |             )
512 | 509 |
513 | 510 |             # We are only ALWAYS interested in the pooled output of the final text encoder
514 | 511 |             pooled_prompt_embeds = prompt_embeds[0]
515 |     | -            prompt_embeds = prompt_embeds.hidden_states[-2]
    | 512 | +            prompt_embeds = prompt_embeds[-1][-2]
516 | 513 |             bs_embed, seq_len, _ = prompt_embeds.shape
517 | 514 |             prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
518 | 515 |             prompt_embeds_list.append(prompt_embeds)
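
The switch from `prompt_embeds.hidden_states[-2]` to `prompt_embeds[-1][-2]` follows from `return_dict=False`: the text encoder now returns a plain tuple instead of a `ModelOutput`, which is generally friendlier to `torch.compile()`. A minimal sketch of what the indices refer to, assuming SDXL's second text encoder (`CLIPTextModelWithProjection`); the checkpoint name and prompt below are illustrative:

```python
# Illustrative sketch, not part of the diff. With return_dict=False the forward
# pass returns a plain tuple, so indexing replaces attribute access.
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection

repo = "stabilityai/stable-diffusion-xl-base-1.0"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer_2")
text_encoder = CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2")

ids = tokenizer(
    "a photo of a cat",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids

with torch.no_grad():
    out = text_encoder(ids, output_hidden_states=True, return_dict=False)

pooled_prompt_embeds = out[0]  # projected/pooled embedding -> SDXL's "text_embeds"
prompt_embeds = out[-1][-2]    # penultimate hidden state, formerly .hidden_states[-2]
```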
|
@@ -955,6 +952,12 @@ def collate_fn(examples):
|
955 | 952 |     if accelerator.is_main_process:
956 | 953 |         accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
957 | 954 |
    | 955 | +    # Function for unwrapping if torch.compile() was used in accelerate.
    | 956 | +    def unwrap_model(model):
    | 957 | +        model = accelerator.unwrap_model(model)
    | 958 | +        model = model._orig_mod if is_compiled_module(model) else model
    | 959 | +        return model
    | 960 | +
958 | 961 |     # Train!
959 | 962 |     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
960 | 963 |
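
Context for the new helper: `torch.compile()` wraps a module in an `OptimizedModule` that keeps the original module under `._orig_mod`, and `accelerator.prepare()` can add its own wrapper (e.g. DDP) on top, so saving and EMA code need to peel off both layers. A small standalone sketch of the compile half; `unwrap_compiled` is a hypothetical stand-in, and the real helper additionally calls `accelerator.unwrap_model()` first:

```python
# Illustrative sketch, not part of the diff (requires torch >= 2.0).
import torch
import torch.nn as nn

module = nn.Linear(4, 4)
compiled = torch.compile(module)  # returns a torch._dynamo OptimizedModule wrapper

def unwrap_compiled(model: nn.Module) -> nn.Module:
    # is_compiled_module() does an isinstance check against OptimizedModule;
    # checking for ._orig_mod is an equivalent, dependency-free approximation.
    return model._orig_mod if hasattr(model, "_orig_mod") else model

assert unwrap_compiled(compiled) is module
assert unwrap_compiled(module) is module  # no-op for a plain module
```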
|
@@ -1054,8 +1057,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
|
1054 | 1057 |                 pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
1055 | 1058 |                 unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
1056 | 1059 |                 model_pred = unet(
1057 |      | -                    noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
1058 |      | -                ).sample
     | 1060 | +                    noisy_model_input,
     | 1061 | +                    timesteps,
     | 1062 | +                    prompt_embeds,
     | 1063 | +                    added_cond_kwargs=unet_added_conditions,
     | 1064 | +                    return_dict=False,
     | 1065 | +                )[0]
1059 | 1066 |
1060 | 1067 |                 # Get the target for loss depending on the prediction type
1061 | 1068 |                 if args.prediction_type is not None:
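
The UNet call gets the same tuple treatment: with `return_dict=False`, `UNet2DConditionModel` returns a tuple whose first element is the tensor that `.sample` used to expose. A toy-sized sketch under assumed settings (tiny config, no SDXL `added_cond_kwargs`, purely illustrative):

```python
# Illustrative sketch, not part of the diff. Uses a tiny UNet2DConditionModel
# config (not SDXL's) to show that return_dict=False + [0] and .sample refer
# to the same prediction; added_cond_kwargs is omitted because this toy config
# has no addition embeddings.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=8,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
).eval()

noisy_model_input = torch.randn(1, 4, 8, 8)
timesteps = torch.tensor([10])
prompt_embeds = torch.randn(1, 4, 32)  # (batch, seq_len, cross_attention_dim)

with torch.no_grad():
    pred_tuple = unet(noisy_model_input, timesteps, prompt_embeds, return_dict=False)[0]
    pred_dataclass = unet(noisy_model_input, timesteps, prompt_embeds).sample

assert torch.allclose(pred_tuple, pred_dataclass)
```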
|
@@ -1206,7 +1213,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
|
1206 | 1213 |
1207 | 1214 |     accelerator.wait_for_everyone()
1208 | 1215 |     if accelerator.is_main_process:
1209 |      | -        unet = accelerator.unwrap_model(unet)
     | 1216 | +        unet = unwrap_model(unet)
1210 | 1217 |         if args.use_ema:
1211 | 1218 |             ema_unet.copy_to(unet.parameters())
1212 | 1219 |
|
|