Commit 5ecf0ed

initial commit

1 parent 0fa0993 commit 5ecf0ed

1 file changed (+82, −47)

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 82 additions & 47 deletions
@@ -1028,11 +1028,6 @@ def encode_prompt(
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
 
-    if hasattr(text_encoders[0], "module"):
-        dtype = text_encoders[0].module.dtype
-    else:
-        dtype = text_encoders[0].dtype
-
     pooled_prompt_embeds_1 = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
         tokenizer=tokenizers[0],
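
Aside: `_encode_prompt_with_clip` is defined earlier in the script and is not part of this diff. As a rough, hypothetical sketch of the pattern (the body below is an assumption for illustration, not the script's actual helper), it tokenizes with the CLIP tokenizer and returns the pooled text embedding:

# Hypothetical sketch of the _encode_prompt_with_clip pattern; the real
# helper lives earlier in the script and may differ.
def _encode_prompt_with_clip_sketch(text_encoder, tokenizer, prompt, device=None):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,  # 77 for CLIP tokenizers
        truncation=True,
        return_tensors="pt",
    )
    # CLIPTextModelWithProjection exposes the projected pooled output as
    # `text_embeds`; HiDream consumes this as a pooled prompt embedding.
    return text_encoder(text_inputs.input_ids.to(device)).text_embeds
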
@@ -1179,21 +1174,50 @@ def main(args):
         exist_ok=True,
     ).repo_id
 
-    # Load the tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(
+    # Load the tokenizers
+    tokenizer_one = CLIPTokenizer.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="tokenizer",
         revision=args.revision,
     )
+    tokenizer_two = CLIPTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+    )
+    tokenizer_three = T5TokenizerFast.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_3",
+        revision=args.revision,
+    )
+
+    tokenizer_four = PreTrainedTokenizerFast.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_4",
+        revision=args.revision,
+    )
+
+    # import correct text encoder classes
+    text_encoder_cls_one = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision
+    )
+    text_encoder_cls_two = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+    )
+    text_encoder_cls_three = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
+    )
+    text_encoder_cls_four = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_4"
+    )
 
     # Load scheduler and models
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
-    text_encoder = Gemma2Model.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
-    )
+    text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, text_encoder_cls_four)
+
     vae = AutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="vae",
@@ -1207,7 +1231,10 @@ def main(args):
     # We only train the additional adapter LoRA layers
     transformer.requires_grad_(False)
     vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
+    text_encoder_one.requires_grad_(False)
+    text_encoder_two.requires_grad_(False)
+    text_encoder_three.requires_grad_(False)
+    text_encoder_four.requires_grad_(False)
 
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
@@ -1226,17 +1253,10 @@ def main(args):
     # keep VAE in FP32 to ensure numerical stability.
     vae.to(dtype=torch.float32)
     transformer.to(accelerator.device, dtype=weight_dtype)
-    # because Gemma2 is particularly suited for bfloat16.
-    text_encoder.to(dtype=torch.bfloat16)
-
-    # Initialize a text encoding pipeline and keep it to CPU for now.
-    text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(
-        args.pretrained_model_name_or_path,
-        vae=None,
-        transformer=None,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-    )
+    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_three.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_four.to(accelerator.device, dtype=weight_dtype)
 
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
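
For reference, `weight_dtype` is derived from the accelerator's mixed-precision setting a few lines above this hunk. The standard pattern used across the diffusers training examples looks like this (shown for context; not part of this diff):

# weight_dtype selection as used throughout the diffusers training
# scripts; non-trainable weights are cast to this dtype for inference.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16
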
@@ -1417,47 +1437,45 @@ def load_model_hook(models, input_dir):
         num_workers=args.dataloader_num_workers,
     )
 
-    def compute_text_embeddings(prompt, text_encoding_pipeline):
-        text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
+    tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four]
+    text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four]
+    def compute_text_embeddings(prompt, text_encoders, tokenizers):
         with torch.no_grad():
-            prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
-                prompt,
-                max_sequence_length=args.max_sequence_length,
-                system_prompt=args.system_prompt,
+            prompt_embeds, pooled_prompt_embeds = encode_prompt(
+                text_encoders, tokenizers, prompt, args.max_sequence_length
             )
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
-        prompt_embeds = prompt_embeds.to(transformer.dtype)
-        return prompt_embeds, prompt_attention_mask
+            prompt_embeds = prompt_embeds.to(accelerator.device)
+            pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+        return prompt_embeds, pooled_prompt_embeds
 
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
     if not train_dataset.custom_instance_prompts:
-        instance_prompt_hidden_states, instance_prompt_attention_mask = compute_text_embeddings(
-            args.instance_prompt, text_encoding_pipeline
+        instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
+            args.instance_prompt, text_encoders, tokenizers
         )
 
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        class_prompt_hidden_states, class_prompt_attention_mask = compute_text_embeddings(
-            args.class_prompt, text_encoding_pipeline
+        class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
+            args.class_prompt, text_encoders, tokenizers
         )
 
     # Clear the memory here
     if not train_dataset.custom_instance_prompts:
-        del text_encoder, tokenizer
+        del text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four, tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four
         free_memory()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
     if not train_dataset.custom_instance_prompts:
         prompt_embeds = instance_prompt_hidden_states
-        prompt_attention_mask = instance_prompt_attention_mask
+        pooled_prompt_embeds = instance_pooled_prompt_embeds
         if args.with_prior_preservation:
             prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
-            prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0)
+            pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
 
     vae_config_scaling_factor = vae.config.scaling_factor
     vae_config_shift_factor = vae.config.shift_factor
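
Illustrative usage of the reworked helper (the prompt string below is a hypothetical example): `prompt_embeds` carries the sequence embeddings and `pooled_prompt_embeds` the pooled CLIP vectors; exact shapes depend on the HiDream checkpoint.

# Illustration only: precompute embeddings for one instance prompt up
# front and reuse them at every training step.
example_embeds, example_pooled = compute_text_embeddings(
    "a photo of sks dog", text_encoders, tokenizers
)
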
@@ -1506,7 +1524,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        tracker_name = "dreambooth-lumina2-lora"
+        tracker_name = "dreambooth-hidream-lora"
         accelerator.init_trackers(tracker_name, config=vars(args))
 
     # Train!
@@ -1580,7 +1598,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             with accelerator.accumulate(models_to_accumulate):
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
-                    prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline)
+                    prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers)
 
                 # Convert images to latent space
                 if args.cache_latents:
@@ -1594,6 +1612,24 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
+                if model_input.shape[-2] != model_input.shape[-1]:
+                    B, C, H, W = model_input.shape
+                    pH, pW = H // transformer.config.patch_size, W // transformer.config.patch_size
+
+                    img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
+                    img_ids = torch.zeros(pH, pW, 3)
+                    img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
+                    img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
+                    img_ids = img_ids.reshape(pH * pW, -1)
+                    img_ids_pad = torch.zeros(transformer.max_seq, 3)
+                    img_ids_pad[: pH * pW, :] = img_ids
+
+                    img_sizes = img_sizes.unsqueeze(0).to(model_input.device)
+                    img_ids = img_ids_pad.unsqueeze(0).to(model_input.device)
+
+                else:
+                    img_sizes = img_ids = None
+
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
                 bsz = model_input.shape[0]
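
The block above builds per-patch position IDs when the latent is non-square. A standalone illustration of the indexing (not part of the script): for a 2×3 patch grid, channel 1 holds each patch's row index and channel 2 its column index.

import torch

# 2x3 patch grid: pH = latent height // patch_size, pW = width // patch_size
pH, pW = 2, 3
img_ids = torch.zeros(pH, pW, 3)
img_ids[..., 1] += torch.arange(pH)[:, None]  # row index per patch
img_ids[..., 2] += torch.arange(pW)[None, :]  # column index per patch
print(img_ids.reshape(pH * pW, -1))
# tensor([[0., 0., 0.],
#         [0., 0., 1.],
#         [0., 0., 2.],
#         [0., 1., 0.],
#         [0., 1., 1.],
#         [0., 1., 2.]])
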
@@ -1612,22 +1648,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
                 # Add noise according to flow matching.
                 # zt = (1 - texp) * x + texp * z1
-                # Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `model_input`
                 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-                noisy_model_input = (1.0 - sigmas) * noise + sigmas * model_input
+                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
 
                 # Predict the noise residual
-                # scale the timesteps (reversal not needed as we used a reverse lerp above already)
-                timesteps = timesteps / noise_scheduler.config.num_train_timesteps
                 model_pred = transformer(
                     hidden_states=noisy_model_input,
                     encoder_hidden_states=prompt_embeds.repeat(len(prompts), 1, 1)
                     if not train_dataset.custom_instance_prompts
                     else prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask.repeat(len(prompts), 1)
+                    pooled_embeds=pooled_prompt_embeds.repeat(len(prompts), 1)
                     if not train_dataset.custom_instance_prompts
-                    else prompt_attention_mask,
+                    else pooled_prompt_embeds,
                     timestep=timesteps,
+                    img_sizes=img_sizes,
+                    img_ids=img_ids,
                     return_dict=False,
                 )[0]
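
The sign flip in the lerp is the substantive change here: with `zt = (1 - sigma) * x + sigma * noise`, a sigma of 1.0 now means pure noise, matching the sigma convention of `FlowMatchEulerDiscreteScheduler`, so the manual timestep rescaling that the Lumina2 variant needed is dropped. A quick sanity check of the direction (a sketch, not part of the script):

import torch

x = torch.randn(1, 4, 8, 8)   # stand-in for clean latents
noise = torch.randn_like(x)

for sigma in (0.0, 1.0):
    zt = (1.0 - sigma) * x + sigma * noise
    expected = x if sigma == 0.0 else noise
    assert torch.allclose(zt, expected)  # sigma=0 -> data, sigma=1 -> noise
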
