@@ -654,7 +654,9 @@ def parse_args(input_args=None):
654654        "uses the value of square root of beta2. Ignored if optimizer is adamW" ,
655655    )
656656    parser .add_argument ("--prodigy_decouple" , type = bool , default = True , help = "Use AdamW style decoupled weight decay" )
657-     parser .add_argument ("--adam_weight_decay" , type = float , default = 1e-04 , help = "Weight decay to use for transformer params" )
657+     parser .add_argument (
658+         "--adam_weight_decay" , type = float , default = 1e-04 , help = "Weight decay to use for transformer params" 
659+     )
658660    parser .add_argument (
659661        "--adam_weight_decay_text_encoder" , type = float , default = 1e-03 , help = "Weight decay to use for text_encoder" 
660662    )
@@ -1506,7 +1508,7 @@ def main(args):
15061508    if  args .train_text_encoder_ti :
15071509        # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK, 
15081510        # TOK2" -> ["TOK", "TOK2"] etc. 
1509-         token_abstraction_list  =  [place_holder .strip () for  place_holder  in  re .split (r' ,\s*'  , args .token_abstraction )]
1511+         token_abstraction_list  =  [place_holder .strip () for  place_holder  in  re .split (r" ,\s*"  , args .token_abstraction )]
15101512        logger .info (f"list of token identifiers: { token_abstraction_list } " )
15111513
15121514        if  args .initializer_concept  is  None :
@@ -1534,8 +1536,10 @@ def main(args):
15341536        for  token_abs , token_replacement  in  token_abstraction_dict .items ():
15351537            new_instance_prompt  =  args .instance_prompt .replace (token_abs , "" .join (token_replacement ))
15361538            if  args .instance_prompt  ==  new_instance_prompt :
1537-                 logger .warning ("Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified " 
1538-                                "--token_abstraction. This may lead to incorrect optimization of text embeddings during pivotal tuning" )
1539+                 logger .warning (
1540+                     "Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified " 
1541+                     "--token_abstraction. This may lead to incorrect optimization of text embeddings during pivotal tuning" 
1542+                 )
15391543            args .instance_prompt  =  new_instance_prompt 
15401544            if  args .with_prior_preservation :
15411545                args .class_prompt  =  args .class_prompt .replace (token_abs , "" .join (token_replacement ))
0 commit comments