
Commit 645a62b

apolinario and multimodalart authored
Add PEFT to advanced training script (#6294)
* Fix ProdigyOPT in SDXL Dreambooth script
* style
* style
* Add PEFT to Advanced Training Script
* style
* style
* ✨ style ✨
* change order for logic operation
* add lora alpha
* style
* Align PEFT to new format
* Update train_dreambooth_lora_sdxl_advanced.py

  Apply #6355 fix

---------

Co-authored-by: multimodalart <[email protected]>
1 parent 6414d4e commit 645a62b

File tree

1 file changed (+49, -87 lines)


examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 49 additions & 87 deletions
@@ -37,6 +37,8 @@
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from safetensors.torch import save_file
@@ -54,10 +56,9 @@
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr, unet_lora_state_dict
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available

@@ -67,39 +68,6 @@
 logger = get_logger(__name__)


-# TODO: This function should be removed once training scripts are rewritten in PEFT
-def text_encoder_lora_state_dict(text_encoder):
-    state_dict = {}
-
-    def text_encoder_attn_modules(text_encoder):
-        from transformers import CLIPTextModel, CLIPTextModelWithProjection
-
-        attn_modules = []
-
-        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
-            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
-                name = f"text_model.encoder.layers.{i}.self_attn"
-                mod = layer.self_attn
-                attn_modules.append((name, mod))
-
-        return attn_modules
-
-    for name, module in text_encoder_attn_modules(text_encoder):
-        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
-
-    return state_dict
-
-
 def save_model_card(
     repo_id: str,
     images=None,
@@ -161,8 +129,6 @@ def save_model_card(
 base_model: {base_model}
 instance_prompt: {instance_prompt}
 license: openrail++
-widget:
-- text: '{validation_prompt if validation_prompt else instance_prompt}'
 ---
 """

@@ -1264,54 +1230,25 @@ def main(args):
             text_encoder_two.gradient_checkpointing_enable()

     # now we will add new LoRA weights to the attention layers
-    # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+    )
+    unet.add_adapter(unet_lora_config)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_one, dtype=torch.float32, rank=args.rank
-        )
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_two, dtype=torch.float32, rank=args.rank
+        text_lora_config = LoraConfig(
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
+        text_encoder_one.add_adapter(text_lora_config)
+        text_encoder_two.add_adapter(text_lora_config)

     # if we use textual inversion, we freeze all parameters except for the token embeddings
     # in text encoder
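
The hunk above replaces the manual LoRALinearLayer injection with a single LoraConfig plus an add_adapter call on each model. As a rough, self-contained sketch of what that injection does, the snippet below applies the same config to a toy attention-like module via peft's inject_adapter_in_model (to my understanding roughly what add_adapter delegates to); the ToyAttention class and the fixed rank are made up for illustration and are not part of the script:

import torch
from peft import LoraConfig, inject_adapter_in_model


class ToyAttention(torch.nn.Module):
    # Stand-in for one attention block of the UNet; module names mirror the
    # target_modules above ("to_q", "to_k", "to_v", "to_out.0").
    def __init__(self, dim=64):
        super().__init__()
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim, dim), torch.nn.Dropout(0.0)])


rank = 4  # plays the role of args.rank
unet_lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
toy = inject_adapter_in_model(unet_lora_config, ToyAttention())

# Only the injected lora_A / lora_B matrices require gradients;
# the original Linear weights are left frozen by the injection.
trainable = [name for name, p in toy.named_parameters() if p.requires_grad]
print(trainable)

Setting lora_alpha equal to r keeps the effective LoRA scaling at 1, matching what the old LoRALinearLayer-based code did implicitly.
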
@@ -1335,6 +1272,17 @@ def main(args):
             else:
                 param.requires_grad = False

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
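
The fp16 branch added above upcasts only the parameters that will actually be optimized, so the frozen backbone can stay in half precision while the LoRA weights get stable float32 optimizer updates. A minimal sketch of that pattern on toy modules (nothing below is taken verbatim from the script; the two Linear layers stand in for a frozen fp16 base layer and an injected LoRA layer):

import torch

# A "base" layer loaded in fp16 and frozen, plus a stand-in for a trainable LoRA layer.
base = torch.nn.Linear(8, 8).to(torch.float16)
base.requires_grad_(False)
lora_like = torch.nn.Linear(8, 8, bias=False).to(torch.float16)

for model in [base, lora_like]:
    for param in model.parameters():
        # only upcast trainable parameters into fp32; frozen fp16 weights keep their dtype
        if param.requires_grad:
            param.data = param.to(torch.float32)

print(base.weight.dtype, lora_like.weight.dtype)  # torch.float16 torch.float32
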
@@ -1346,11 +1294,15 @@ def save_model_hook(models, weights, output_dir):

             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_lora_state_dict(model)
+                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
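
The save hook now extracts only the adapter weights with get_peft_model_state_dict and renames the keys with convert_state_dict_to_diffusers so the existing diffusers LoRA serialization understands them. A small, self-contained sketch of that two-step conversion, assuming a made-up Proj module and arbitrary dimensions purely for illustration:

import torch
from peft import LoraConfig, inject_adapter_in_model
from peft.utils import get_peft_model_state_dict

from diffusers.utils import convert_state_dict_to_diffusers


class Proj(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(16, 16)  # mirrors one of the targeted projections


model = inject_adapter_in_model(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q"]), Proj())

# Adapter-only state dict in PEFT naming (lora_A / lora_B keys) ...
peft_state_dict = get_peft_model_state_dict(model)
# ... renamed to the key convention diffusers' LoRA loaders expect.
diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict)
print(sorted(diffusers_state_dict))

These per-model dictionaries are what the hook then hands to the LoRA checkpoint writer.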

@@ -1407,6 +1359,12 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+    if args.train_text_encoder:
+        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
     # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
     freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
@@ -1997,13 +1955,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = unet_lora_state_dict(unet)
+        unet_lora_layers = get_peft_model_state_dict(unet)

         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+            text_encoder_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            )
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-            text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            )
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
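
Hedged usage sketch: outside the hunks shown here, the script exports these collected layer dictionaries as a LoRA checkpoint (via the SDXL pipeline's save_lora_weights), and the result loads like any other diffusers LoRA. The base model id, output directory, prompt, and filename below are placeholders, not values from this commit:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# Directory holding the exported LoRA weights (typically a pytorch_lora_weights.safetensors file).
pipe.load_lora_weights("path/to/output_dir")
image = pipe(prompt="a photo of sks dog in a bucket").images[0]
image.save("lora_sample.png")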
