Commit 05faf32

SDXL text-to-image torch compatible (#6550)
* torch compatible
* code quality fix
* ruff style
* ruff format
1 parent a080f0d commit 05faf32

File tree

1 file changed: +17 -10 lines changed

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 17 additions & 10 deletions
@@ -44,16 +44,12 @@
 from transformers import AutoTokenizer, PretrainedConfig

 import diffusers
-from diffusers import (
-    AutoencoderKL,
-    DDPMScheduler,
-    StableDiffusionXLPipeline,
-    UNet2DConditionModel,
-)
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -508,11 +504,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
             prompt_embeds = text_encoder(
                 text_input_ids.to(text_encoder.device),
                 output_hidden_states=True,
+                return_dict=False,
             )

             # We are only ALWAYS interested in the pooled output of the final text encoder
             pooled_prompt_embeds = prompt_embeds[0]
-            prompt_embeds = prompt_embeds.hidden_states[-2]
+            prompt_embeds = prompt_embeds[-1][-2]
             bs_embed, seq_len, _ = prompt_embeds.shape
             prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
             prompt_embeds_list.append(prompt_embeds)
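Why this indexing works (context, not part of the diff): with `return_dict=False` the text encoder returns a plain tuple rather than a `ModelOutput` dataclass, which is friendlier to `torch.compile`. For the final SDXL text encoder (a `CLIPTextModelWithProjection`), element `[0]` is the projected pooled embedding and the last element is the tuple of per-layer hidden states, so `prompt_embeds[-1][-2]` selects the same penultimate hidden state that `.hidden_states[-2]` did. A minimal sketch of that tuple layout; the checkpoint name and prompt are illustrative assumptions:

```python
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection

# Illustrative checkpoint, not the one the SDXL training script actually loads.
name = "openai/clip-vit-large-patch14"
tokenizer = AutoTokenizer.from_pretrained(name)
text_encoder = CLIPTextModelWithProjection.from_pretrained(name)

text_input_ids = tokenizer("a photo of a cat", return_tensors="pt").input_ids
with torch.no_grad():
    out = text_encoder(text_input_ids, output_hidden_states=True, return_dict=False)

pooled = out[0]                  # projected pooled embedding (was out.text_embeds)
hidden_states = out[-1]          # tuple of per-layer hidden states
penultimate = hidden_states[-2]  # what the script now calls prompt_embeds[-1][-2]
print(pooled.shape, penultimate.shape)
```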
@@ -955,6 +952,12 @@ def collate_fn(examples):
     if accelerator.is_main_process:
         accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))

+    # Function for unwrapping if torch.compile() was used in accelerate.
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
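Background on the helper (not part of the diff): `torch.compile` wraps a module in a dynamo `OptimizedModule` and keeps the original underneath as `._orig_mod`; `is_compiled_module` is essentially an `isinstance` check for that wrapper. A standalone sketch, assuming torch >= 2.0:

```python
import torch

net = torch.nn.Linear(4, 4)
compiled = torch.compile(net)

# torch.compile returns a wrapper that forwards to the original module.
print(type(compiled).__name__)    # OptimizedModule
print(compiled._orig_mod is net)  # True: the original lives under ._orig_mod
```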

@@ -1054,8 +1057,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
                 unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
                 model_pred = unet(
-                    noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
-                ).sample
+                    noisy_model_input,
+                    timesteps,
+                    prompt_embeds,
+                    added_cond_kwargs=unet_added_conditions,
+                    return_dict=False,
+                )[0]

                 # Get the target for loss depending on the prediction type
                 if args.prediction_type is not None:
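As in `encode_prompt`, `return_dict=False` makes the UNet forward return a plain tuple, so `[0]` replaces the `.sample` attribute access. A minimal sketch of the calling convention with a tiny, randomly initialized UNet; the config values below are illustrative assumptions, not SDXL's:

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=8,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    layers_per_block=1,
    cross_attention_dim=32,
)

sample = torch.randn(1, 4, 8, 8)                # noisy latents
timesteps = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)  # text conditioning

# With return_dict=False the forward returns a tuple; [0] is the prediction
# that the dataclass version exposed as .sample.
model_pred = unet(sample, timesteps, encoder_hidden_states, return_dict=False)[0]
print(model_pred.shape)  # torch.Size([1, 4, 8, 8])
```

This toy config has no additive condition embeddings, so `added_cond_kwargs` is omitted here; the SDXL script still passes it, as the diff shows.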
@@ -1206,7 +1213,7 @@ def compute_time_ids(original_size, crops_coords_top_left):

     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())
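Why unwrapping matters before the final save (context, not from the diff): a compiled module's `state_dict` keys carry a `_orig_mod.` prefix, so saving without unwrapping would produce a checkpoint whose keys no longer match a plain `UNet2DConditionModel`. A small sketch:

```python
import torch

net = torch.nn.Linear(2, 2)
compiled = torch.compile(net)

print(list(net.state_dict()))       # ['weight', 'bias']
print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']

# unwrap_model strips both the accelerate wrapper and this prefix by
# returning ._orig_mod, keeping saved checkpoint keys prefix-free.
```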
