Skip to content

Commit c8ddd83

Browse files
committed
style
1 parent b791e13 commit c8ddd83

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,9 @@ def parse_args(input_args=None):
654654
"uses the value of square root of beta2. Ignored if optimizer is adamW",
655655
)
656656
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
657-
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params")
657+
parser.add_argument(
658+
"--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params"
659+
)
658660
parser.add_argument(
659661
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
660662
)
@@ -1506,7 +1508,7 @@ def main(args):
15061508
if args.train_text_encoder_ti:
15071509
# we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK,
15081510
# TOK2" -> ["TOK", "TOK2"] etc.
1509-
token_abstraction_list = [place_holder.strip() for place_holder in re.split(r',\s*', args.token_abstraction)]
1511+
token_abstraction_list = [place_holder.strip() for place_holder in re.split(r",\s*", args.token_abstraction)]
15101512
logger.info(f"list of token identifiers: {token_abstraction_list}")
15111513

15121514
if args.initializer_concept is None:
@@ -1534,8 +1536,10 @@ def main(args):
15341536
for token_abs, token_replacement in token_abstraction_dict.items():
15351537
new_instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
15361538
if args.instance_prompt == new_instance_prompt:
1537-
logger.warning("Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified "
1538-
"--token_abstraction. This may lead to incorrect optimization of text embeddings during pivotal tuning")
1539+
logger.warning(
1540+
"Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified "
1541+
"--token_abstraction. This may lead to incorrect optimization of text embeddings during pivotal tuning"
1542+
)
15391543
args.instance_prompt = new_instance_prompt
15401544
if args.with_prior_preservation:
15411545
args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement))

0 commit comments

Comments
 (0)