Skip to content

Commit 936cbeb

Browse files
committed
fix style and quality
1 parent 6f6aa1b commit 936cbeb

File tree

1 file changed: 7 additions, 15 deletions

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 7 additions & 15 deletions

@@ -67,8 +67,8 @@
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
     convert_unet_state_dict_to_peft,
+    is_peft_version,
     is_wandb_available,
-    is_peft_version
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
@@ -1193,31 +1193,23 @@ def get_lora_config(rank, use_dora, target_modules):
         }
         if use_dora and is_peft_version("<", "0.9.0"):
             raise ValueError(
-                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-            )
+                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+            )
         else:
             base_config["use_dora"] = True
-
+
         return LoraConfig(**base_config)
-
+
     # now we will add new LoRA weights to the attention layers
     unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
-    unet_lora_config = get_lora_config(
-        rank=args.rank,
-        use_dora=args.use_dora,
-        target_modules=unet_target_modules
-    )
+    unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
     unet.add_adapter(unet_lora_config)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
-        text_lora_config = get_lora_config(
-            rank=args.rank,
-            use_dora=args.use_dora,
-            target_modules=text_target_modules
-        )
+        text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)

(Note: the removed/added line pairs whose visible text is identical differ only in whitespace — indentation and trailing spaces — which the page rendering cannot display; exact original indentation should be confirmed against the repository.)

Comments (0)