@@ -37,6 +37,8 @@
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
+ from peft import LoraConfig
+ from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import save_file
@@ -54,10 +56,9 @@
    UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
- from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
- from diffusers.training_utils import compute_snr, unet_lora_state_dict
- from diffusers.utils import check_min_version, is_wandb_available
+ from diffusers.training_utils import compute_snr
+ from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -67,39 +68,6 @@
logger = get_logger(__name__)


- # TODO: This function should be removed once training scripts are rewritten in PEFT
- def text_encoder_lora_state_dict(text_encoder):
-     state_dict = {}
-
-     def text_encoder_attn_modules(text_encoder):
-         from transformers import CLIPTextModel, CLIPTextModelWithProjection
-
-         attn_modules = []
-
-         if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
-             for i, layer in enumerate(text_encoder.text_model.encoder.layers):
-                 name = f"text_model.encoder.layers.{i}.self_attn"
-                 mod = layer.self_attn
-                 attn_modules.append((name, mod))
-
-         return attn_modules
-
-     for name, module in text_encoder_attn_modules(text_encoder):
-         for k, v in module.q_proj.lora_linear_layer.state_dict().items():
-             state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
-
-         for k, v in module.k_proj.lora_linear_layer.state_dict().items():
-             state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
-
-         for k, v in module.v_proj.lora_linear_layer.state_dict().items():
-             state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
-
-         for k, v in module.out_proj.lora_linear_layer.state_dict().items():
-             state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
-
-     return state_dict
-
-
def save_model_card(
    repo_id: str,
    images=None,
@@ -161,8 +129,6 @@ def save_model_card(
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
- widget:
- - text: '{validation_prompt if validation_prompt else instance_prompt}'
---
"""
@@ -1264,54 +1230,25 @@ def main(args):
            text_encoder_two.gradient_checkpointing_enable()

    # now we will add new LoRA weights to the attention layers
-     # Set correct lora layers
-     unet_lora_parameters = []
-     for attn_processor_name, attn_processor in unet.attn_processors.items():
-         # Parse the attention module.
-         attn_module = unet
-         for n in attn_processor_name.split(".")[:-1]:
-             attn_module = getattr(attn_module, n)
-
-         # Set the `lora_layer` attribute of the attention-related matrices.
-         attn_module.to_q.set_lora_layer(
-             LoRALinearLayer(
-                 in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-             )
-         )
-         attn_module.to_k.set_lora_layer(
-             LoRALinearLayer(
-                 in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-             )
-         )
-         attn_module.to_v.set_lora_layer(
-             LoRALinearLayer(
-                 in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-             )
-         )
-         attn_module.to_out[0].set_lora_layer(
-             LoRALinearLayer(
-                 in_features=attn_module.to_out[0].in_features,
-                 out_features=attn_module.to_out[0].out_features,
-                 rank=args.rank,
-             )
-         )
-
-         # Accumulate the LoRA params to optimize.
-         unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-         unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-         unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-         unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+     unet_lora_config = LoraConfig(
+         r=args.rank,
+         lora_alpha=args.rank,
+         init_lora_weights="gaussian",
+         target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+     )
+     unet.add_adapter(unet_lora_config)

    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
    # So, instead, we monkey-patch the forward calls of its attention-blocks.
    if args.train_text_encoder:
-         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-         text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
-             text_encoder_one, dtype=torch.float32, rank=args.rank
-         )
-         text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
-             text_encoder_two, dtype=torch.float32, rank=args.rank
+         text_lora_config = LoraConfig(
+             r=args.rank,
+             lora_alpha=args.rank,
+             init_lora_weights="gaussian",
+             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        )
+         text_encoder_one.add_adapter(text_lora_config)
+         text_encoder_two.add_adapter(text_lora_config)

    # if we use textual inversion, we freeze all parameters except for the token embeddings
    # in text encoder
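For orientation, the PEFT pattern introduced above can be exercised on a standalone UNet outside the training script. The snippet below is an illustrative sketch only; the checkpoint id and rank are placeholders, not taken from this diff:

    from diffusers import UNet2DConditionModel
    from peft import LoraConfig

    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
    )
    unet.requires_grad_(False)  # freeze the base weights

    lora_config = LoraConfig(
        r=4,
        lora_alpha=4,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    unet.add_adapter(lora_config)  # inject LoRA matrices into the targeted projections

    # Only the injected LoRA parameters are trainable; the base weights stay frozen.
    lora_params = [p for p in unet.parameters() if p.requires_grad]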
@@ -1335,6 +1272,17 @@ def main(args):
            else:
                param.requires_grad = False

+     # Make sure the trainable params are in float32.
+     if args.mixed_precision == "fp16":
+         models = [unet]
+         if args.train_text_encoder:
+             models.extend([text_encoder_one, text_encoder_two])
+         for model in models:
+             for param in model.parameters():
+                 # only upcast trainable parameters (LoRA) into fp32
+                 if param.requires_grad:
+                     param.data = param.to(torch.float32)
+
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
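As a side note, recent diffusers releases also ship a small helper that performs the same selective upcast; a sketch, assuming `cast_training_params` is available in the installed version:

    import torch
    from diffusers.training_utils import cast_training_params

    # Upcast only parameters that require grad (the LoRA weights) to fp32,
    # leaving the frozen fp16 base weights untouched.
    cast_training_params([unet, text_encoder_one, text_encoder_two], dtype=torch.float32)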
@@ -1346,11 +1294,15 @@ def save_model_hook(models, weights, output_dir):
            for model in models:
                if isinstance(model, type(accelerator.unwrap_model(unet))):
-                     unet_lora_layers_to_save = unet_lora_state_dict(model)
+                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                     text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                     text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                         get_peft_model_state_dict(model)
+                     )
                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                     text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                     text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                         get_peft_model_state_dict(model)
+                     )
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")
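The two helpers chained in the hook split the work: `get_peft_model_state_dict` extracts only the adapter weights from the wrapped model (PEFT-style keys), and `convert_state_dict_to_diffusers` renames them to the layout the diffusers LoRA loaders expect. A minimal sketch of that hand-off, wrapped in a hypothetical helper (`peft_to_diffusers_lora` is illustrative, not part of the script):

    from diffusers.utils import convert_state_dict_to_diffusers
    from peft.utils import get_peft_model_state_dict

    def peft_to_diffusers_lora(model):
        # LoRA tensors only, in PEFT naming
        peft_state_dict = get_peft_model_state_dict(model)
        # same tensors, keys renamed to the layout diffusers' LoRA loaders expect
        return convert_state_dict_to_diffusers(peft_state_dict)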
@@ -1407,6 +1359,12 @@ def load_model_hook(models, input_dir):
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

+     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+     if args.train_text_encoder:
+         text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+         text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
    # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
    freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
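Downstream, these filtered lists are what the optimizer actually updates; roughly, and with placeholder learning rates rather than the script's own arguments:

    import torch

    unet_lr, text_encoder_lr = 1e-4, 5e-5  # placeholders for illustration

    params_to_optimize = [{"params": unet_lora_parameters, "lr": unet_lr}]
    if args.train_text_encoder:
        params_to_optimize += [
            {"params": text_lora_parameters_one, "lr": text_encoder_lr},
            {"params": text_lora_parameters_two, "lr": text_encoder_lr},
        ]
    optimizer = torch.optim.AdamW(params_to_optimize)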
@@ -1997,13 +1955,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        unet = unet.to(torch.float32)
-         unet_lora_layers = unet_lora_state_dict(unet)
+         unet_lora_layers = get_peft_model_state_dict(unet)

        if args.train_text_encoder:
            text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-             text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+             text_encoder_lora_layers = convert_state_dict_to_diffusers(
+                 get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+             )
            text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-             text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
+             text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+                 get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+             )
        else:
            text_encoder_lora_layers = None
            text_encoder_2_lora_layers = None
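In the surrounding script, these dictionaries are then handed to the pipeline's LoRA serializer; a sketch of the usual call, with a placeholder output directory:

    from diffusers import StableDiffusionXLPipeline

    StableDiffusionXLPipeline.save_lora_weights(
        save_directory="output/lora",  # placeholder path
        unet_lora_layers=unet_lora_layers,
        text_encoder_lora_layers=text_encoder_lora_layers,
        text_encoder_2_lora_layers=text_encoder_2_lora_layers,
    )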