Commit 9274126

change code to use PEFT as discussed in issue 10062
1 parent 2b9bc1d commit 9274126
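
In short, the commit swaps the hand-rolled LoRALinearLayer wiring for PEFT's adapter injection. A minimal sketch of the pattern the diff below adopts (the model id, rank, and learning rate here are illustrative placeholders, not the script's arguments):

import torch
from peft import LoraConfig
from diffusers import UNet2DConditionModel

# Load the UNet and freeze every base weight; only the injected LoRA matrices will train.
unet = UNet2DConditionModel.from_pretrained("timbrooks/instruct-pix2pix", subfolder="unet")
unet.requires_grad_(False)

# Describe the adapter: rank, scaling, init scheme, and which attention projections to wrap.
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)

# After add_adapter, only the LoRA parameters require grad; hand exactly those to the optimizer.
lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
optimizer = torch.optim.AdamW(lora_layers, lr=1e-4)

Checkpoints are then exported through get_peft_model_state_dict and convert_state_dict_to_diffusers before StableDiffusionInstructPix2PixPipeline.save_lora_weights, as the hunks below show.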

File tree

1 file changed (+122, -47 lines)

examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py

Lines changed: 122 additions & 47 deletions
@@ -43,6 +43,8 @@
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -51,9 +53,12 @@
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
 from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.training_utils import cast_training_params, EMAModel
+from diffusers.utils import check_min_version, deprecate, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
 if is_wandb_available():

     import wandb
@@ -69,6 +74,47 @@
 }
 WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]

+def save_model_card(
+    repo_id: str,
+    images: list = None,
+    base_model: str = None,
+    dataset_name: str = None,
+    repo_folder: str = None,
+):
+    img_str = ""
+    if images is not None:
+        for i, image in enumerate(images):
+            image.save(os.path.join(repo_folder, f"image_{i}.png"))
+            img_str += f"![img_{i}](./image_{i}.png)\n"
+
+    model_description = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
+{img_str}
+"""
+
+    model_card = load_or_create_model_card(
+        repo_id_or_path=repo_id,
+        from_training=True,
+        license="creativeml-openrail-m",
+        base_model=base_model,
+        model_description=model_description,
+        inference=True,
+    )
+
+    tags = [
+        "stable-diffusion",
+        "stable-diffusion-diffusers",
+        "text-to-image",
+        "diffusers",
+        "diffusers-training",
+        "lora",
+    ]
+    model_card = populate_model_card(model_card, tags=tags)
+
+    model_card.save(os.path.join(repo_folder, "README.md"))
+
+
 def log_validation(
     pipeline,
     args,
@@ -535,43 +581,35 @@ def main():
     unet.requires_grad_(False)

     # referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16

-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
+    # Freeze the unet parameters before adding adapters
+    for param in unet.parameters():
+        param.requires_grad_(False)

-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+    )
+
+    # Move unet, vae and text_encoder to device and cast to weight_dtype
+    unet.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+    # Add adapter and make sure the trainable params are in float32.
+    unet.add_adapter(unet_lora_config)
+    if args.mixed_precision == "fp16":
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)

     # Create EMA for the unet.
     if args.use_ema:
@@ -590,6 +628,8 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

+    lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -657,9 +697,9 @@ def load_model_hook(models, input_dir):
     else:
         optimizer_cls = torch.optim.AdamW

-    # train on only unet_lora_parameters
+    # train on only lora_layers
     optimizer = optimizer_cls(
-        unet_lora_parameters,
+        lora_layers,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -817,8 +857,8 @@ def collate_fn(examples):
     )

     # Prepare everything with our `accelerator`.
-    unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, optimizer, train_dataloader, lr_scheduler
     )

     if args.use_ema:
@@ -964,7 +1004,7 @@ def collate_fn(examples):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                 # Predict the noise residual and compute loss
-                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                 # Gather the losses across all processes for logging (if we use distributed training).
@@ -974,15 +1014,15 @@ def collate_fn(examples):
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(unet_lora_parameters, args.max_grad_norm)
+                    accelerator.clip_grad_norm_(lora_layers, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 if args.use_ema:
-                    ema_unet.step(unet_lora_parameters)
+                    ema_unet.step(lora_layers)
                 progress_bar.update(1)
                 global_step += 1
                 accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -1012,6 +1052,16 @@ def collate_fn(examples):

                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
+                        unwrapped_unet = unwrap_model(unet)
+                        unet_lora_state_dict = convert_state_dict_to_diffusers(
+                            get_peft_model_state_dict(unwrapped_unet)
+                        )
+
+                        StableDiffusionInstructPix2PixPipeline.save_lora_weights(
+                            save_directory=save_path,
+                            unet_lora_layers=unet_lora_state_dict,
+                            safe_serialization=True,
+                        )
                         logger.info(f"Saved state to {save_path}")

             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -1064,10 +1114,20 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())

+        # store only LORA layers
+        unet = unet.to(torch.float32)
+
+        unwrapped_unet = unwrap_model(unet)
+        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
+        StableDiffusionInstructPix2PixPipeline.save_lora_weights(
+            save_directory=args.output_dir,
+            unet_lora_layers=unet_lora_state_dict,
+            safe_serialization=True,
+        )
+
         pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             text_encoder=unwrap_model(text_encoder),
@@ -1076,10 +1136,25 @@ def collate_fn(examples):
             revision=args.revision,
             variant=args.variant,
         )
-        # store only LORA layers
-        unet.save_attn_procs(args.output_dir)
+        pipeline.load_lora_weights(args.output_dir)
+
+        images = None
+        if (args.val_image_url is not None) and (args.validation_prompt is not None):
+            images = log_validation(
+                pipeline,
+                args,
+                accelerator,
+                generator,
+            )

         if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                images=images,
+                base_model=args.pretrained_model_name_or_path,
+                dataset_name=args.dataset_name,
+                repo_folder=args.output_dir,
+            )
             upload_folder(
                 repo_id=repo_id,
                 folder_path=args.output_dir,

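For reference, the LoRA weights written by save_lora_weights can be loaded back with load_lora_weights, which is exactly what the script does before its final validation pass. A hedged inference sketch (model id, paths, prompt, and guidance values are placeholders):

import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

# Base InstructPix2Pix checkpoint; use the same base model the LoRA was trained against.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

# Directory the training script wrote its LoRA weights into (--output_dir).
pipeline.load_lora_weights("path/to/output_dir")

image = load_image("https://example.com/room.png")  # placeholder input image
edited = pipeline(
    "make it a watercolor painting",
    image=image,
    num_inference_steps=20,
    image_guidance_scale=1.5,
).images[0]
edited.save("edited.png")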