 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from safetensors.torch import save_file
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr, unet_lora_state_dict
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available


 logger = get_logger(__name__)


-# TODO: This function should be removed once training scripts are rewritten in PEFT
-def text_encoder_lora_state_dict(text_encoder):
-    state_dict = {}
-
-    def text_encoder_attn_modules(text_encoder):
-        from transformers import CLIPTextModel, CLIPTextModelWithProjection
-
-        attn_modules = []
-
-        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
-            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
-                name = f"text_model.encoder.layers.{i}.self_attn"
-                mod = layer.self_attn
-                attn_modules.append((name, mod))
-
-        return attn_modules
-
-    for name, module in text_encoder_attn_modules(text_encoder):
-        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
-
-    return state_dict
-
-
 def save_model_card(
     repo_id: str,
     images=None,
@@ -161,8 +129,6 @@ def save_model_card(
 base_model: {base_model}
 instance_prompt: {instance_prompt}
 license: openrail++
-widget:
-- text: '{validation_prompt if validation_prompt else instance_prompt}'
 ---
 """

@@ -1264,54 +1230,25 @@ def main(args):
             text_encoder_two.gradient_checkpointing_enable()

     # now we will add new LoRA weights to the attention layers
-    # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+    )
+    unet.add_adapter(unet_lora_config)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_one, dtype=torch.float32, rank=args.rank
-        )
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_two, dtype=torch.float32, rank=args.rank
+        text_lora_config = LoraConfig(
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
+        text_encoder_one.add_adapter(text_lora_config)
+        text_encoder_two.add_adapter(text_lora_config)

     # if we use textual inversion, we freeze all parameters except for the token embeddings
     # in text encoder
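Note on the PEFT-based setup above: `LoraConfig` declares the rank (`r`), the scaling factor (`lora_alpha`), the weight-init scheme, and the names of the sub-modules to wrap, and `add_adapter` injects trainable low-rank matrices into exactly those modules. A self-contained sketch of the same pattern follows; the SDXL checkpoint id and the rank of 4 are illustrative, the script itself uses `args.pretrained_model_name_or_path` and `args.rank`.

from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.requires_grad_(False)  # freeze the base weights, as the training script does

lora_config = LoraConfig(
    r=4,  # illustrative rank; the script passes args.rank
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # UNet attention projections
)
unet.add_adapter(lora_config)

# Only the injected LoRA matrices should now require gradients.
trainable = [n for n, p in unet.named_parameters() if p.requires_grad]
print(len(trainable), trainable[:2])

The text-encoder config differs only in `target_modules`, because CLIP's attention projections are named `q_proj`/`k_proj`/`v_proj`/`out_proj` rather than `to_q`/`to_k`/`to_v`/`to_out.0`.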
@@ -1335,6 +1272,17 @@ def main(args):
             else:
                 param.requires_grad = False

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
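Why the added upcast matters: with `--mixed_precision fp16` the models are loaded in half precision, and leaving the trainable LoRA parameters in fp16 gives unstable updates (and, with a gradient scaler, errors such as "Attempting to unscale FP16 gradients"), so only parameters with `requires_grad=True` are promoted back to fp32 while the frozen base weights stay in fp16. A self-contained toy illustration of the same pattern (stand-in `Linear` modules, not the script's models):

import torch

# One frozen fp16 module and one trainable fp16 module; upcast only the trainable one.
frozen = torch.nn.Linear(8, 8).to(torch.float16).requires_grad_(False)
trainable = torch.nn.Linear(8, 8).to(torch.float16)

for module in (frozen, trainable):
    for param in module.parameters():
        if param.requires_grad:
            param.data = param.to(torch.float32)

print(frozen.weight.dtype, trainable.weight.dtype)  # torch.float16 torch.float32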
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
@@ -1346,11 +1294,15 @@ def save_model_hook(models, weights, output_dir):

             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_lora_state_dict(model)
+                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

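In the hook above, `get_peft_model_state_dict` extracts just the adapter weights in PEFT's key naming, and `convert_state_dict_to_diffusers` remaps those keys into the layout diffusers' LoRA loaders expect. The collected dicts are then written out by the pipeline's LoRA saver; the call below is a sketch based on the standard SDXL LoRA training scripts (it is not shown in this hunk) and reuses the hook's variable names:

from diffusers import StableDiffusionXLPipeline

# Sketch only: write the converted LoRA state dicts collected above to `output_dir`
# in the safetensors layout that `load_lora_weights` can read back.
StableDiffusionXLPipeline.save_lora_weights(
    output_dir,
    unet_lora_layers=unet_lora_layers_to_save,
    text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
    text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
)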
@@ -1407,6 +1359,12 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+    if args.train_text_encoder:
+        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
     # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
     freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)

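These `requires_grad` filters replace the hand-accumulated `unet_lora_parameters` list from the deleted code path: after `add_adapter`, the only parameters left trainable are the LoRA matrices, so filtering on `requires_grad` recovers exactly the set to optimize. The lists then feed the optimizer's parameter groups, roughly as sketched below (the real script's learning rates and optimizer options may differ):

import torch

# Assumed shape of the optimizer setup; parameter groups let the text encoders
# use a different learning rate than the UNet if desired.
params_to_optimize = [{"params": unet_lora_parameters, "lr": args.learning_rate}]
if args.train_text_encoder:
    params_to_optimize.append({"params": text_lora_parameters_one, "lr": args.learning_rate})
    params_to_optimize.append({"params": text_lora_parameters_two, "lr": args.learning_rate})

optimizer = torch.optim.AdamW(params_to_optimize)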
@@ -1997,13 +1955,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = unet_lora_state_dict(unet)
+        unet_lora_layers = get_peft_model_state_dict(unet)

         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+            text_encoder_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            )
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-            text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            )
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
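After training, the state dicts assembled here are saved to `args.output_dir` (via the pipeline's `save_lora_weights`, as in the standard SDXL LoRA scripts) and can be loaded onto a fresh SDXL pipeline for inference. A minimal sketch, where "output_dir" and the prompt are placeholders:

import torch
from diffusers import StableDiffusionXLPipeline

# Sketch: load the trained LoRA for inference. "output_dir" stands in for the
# directory the script saved to (args.output_dir).
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("output_dir")
image = pipe("a photo of sks dog", num_inference_steps=25).images[0]
image.save("sample.png")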