@@ -659,6 +659,9 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+
+    parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
+
     parser.add_argument(
         "--use_dora",
         action="store_true",
@@ -1199,10 +1202,11 @@ def main(args):
             text_encoder_one.gradient_checkpointing_enable()
             text_encoder_two.gradient_checkpointing_enable()
 
-    def get_lora_config(rank, use_dora, target_modules):
+    def get_lora_config(rank, dropout, use_dora, target_modules):
         base_config = {
             "r": rank,
             "lora_alpha": rank,
+            "lora_dropout": dropout,
             "init_lora_weights": "gaussian",
             "target_modules": target_modules,
         }
@@ -1218,14 +1222,24 @@ def get_lora_config(rank, use_dora, target_modules):
 
     # now we will add new LoRA weights to the attention layers
     unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
-    unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
+    unet_lora_config = get_lora_config(
+        rank=args.rank,
+        dropout=args.lora_dropout,
+        use_dora=args.use_dora,
+        target_modules=unet_target_modules,
+    )
     unet.add_adapter(unet_lora_config)
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
-        text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
+        text_lora_config = get_lora_config(
+            rank=args.rank,
+            dropout=args.lora_dropout,
+            use_dora=args.use_dora,
+            target_modules=text_target_modules,
+        )
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
 
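For context, a minimal sketch of the helper after this change, assuming PEFT's LoraConfig is what base_config ultimately feeds (the real helper builds a dict and gates use_dora on the installed peft version, which falls outside this hunk). The new dropout argument maps onto LoraConfig.lora_dropout, which applies dropout to the adapter branch's input activations during training and is a no-op at inference:

from peft import LoraConfig

def get_lora_config(rank, dropout, use_dora, target_modules):
    # Sketch only: the new `dropout` argument is forwarded as `lora_dropout`;
    # everything else matches the config the diff constructs.
    return LoraConfig(
        r=rank,
        lora_alpha=rank,
        lora_dropout=dropout,  # wired from the new --lora_dropout flag
        init_lora_weights="gaussian",
        target_modules=target_modules,
        use_dora=use_dora,  # the real helper version-gates this on peft
    )

With the flag in place, passing e.g. --lora_dropout 0.1 to the training script regularizes both the UNet adapters and (when --train_text_encoder is set) the text-encoder adapters; the default of 0.0 preserves the previous behavior.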