1616
1717try :
1818 from extensions .sd_dreambooth_extension .dreambooth import shared as shared
19- from extensions .sd_dreambooth_extension .dreambooth .dataclasses .db_config import from_file
19+ from extensions .sd_dreambooth_extension .dreambooth .dataclasses .db_config import from_file , DreamboothConfig
2020 from extensions .sd_dreambooth_extension .dreambooth .shared import status
2121 from extensions .sd_dreambooth_extension .dreambooth .utils .model_utils import unload_system_models , \
2222 reload_system_models , \
2626 from extensions .sd_dreambooth_extension .lora_diffusion .lora import merge_lora_to_model
2727except :
2828 from dreambooth .dreambooth import shared as shared # noqa
29- from dreambooth .dreambooth .dataclasses .db_config import from_file # noqa
29+ from dreambooth .dreambooth .dataclasses .db_config import from_file , DreamboothConfig # noqa
3030 from dreambooth .dreambooth .shared import status # noqa
3131 from dreambooth .dreambooth .utils .model_utils import unload_system_models , reload_system_models , \
3232 disable_safe_unpickle , enable_safe_unpickle , import_model_class_from_model_name_or_path # noqa
@@ -338,13 +338,13 @@ def get_model_path(working_dir: str, model_name: str = "", file_extra: str = "")
338338 return None
339339
340340
341- def compile_checkpoint (model_name : str , lora_path : str = None , reload_models : bool = True , log : bool = True ,
341+ def compile_checkpoint (model_name : str , lora_file_name : str = None , reload_models : bool = True , log : bool = True ,
342342 snap_rev : str = "" ):
343343 """
344344
345345 @param model_name: The model name to compile
346346 @param reload_models: Whether to reload the system list of checkpoints.
347- @param lora_path : The path to a lora pt file to merge with the unet. Auto set during training.
347+ @param lora_file_name : The path to a lora pt file to merge with the unet. Auto set during training.
348348 @param log: Whether to print messages to console/UI.
349349 @param snap_rev: The revision of snapshot to load from
350350 @return: status: What happened, path: Checkpoint path
@@ -355,8 +355,8 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
355355 status .job_count = 7
356356
357357 config = from_file (model_name )
358- if lora_path is None and config .lora_model_name :
359- lora_path = config .lora_model_name
358+ if lora_file_name is None and config .lora_model_name :
359+ lora_file_name = config .lora_model_name
360360 save_model_name = model_name if config .custom_model_name == "" else config .custom_model_name
361361 if config .custom_model_name == "" :
362362 printi (f"Compiling checkpoint for { model_name } ..." , log = log )
@@ -418,10 +418,9 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
418418 pass
419419
420420 # Apply LoRA to the unet
421- if lora_path is not None and lora_path != "" :
421+ if lora_file_name is not None and lora_file_name != "" :
422422 unet_model = UNet2DConditionModel ().from_pretrained (os .path .dirname (unet_path ))
423- lora_rev = apply_lora (unet_model , lora_path , config .lora_unet_rank , config .lora_weight , "cpu" , False ,
424- config .use_lora_extended )
423+ lora_rev = apply_lora (config , unet_model , lora_file_name , "cpu" , False )
425424 unet_state_dict = copy .deepcopy (unet_model .state_dict ())
426425 del unet_model
427426 if lora_rev is not None :
@@ -448,9 +447,9 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
448447 printi ("Converting text encoder..." , log = log )
449448
450449 # Apply lora weights to the tenc
451- if lora_path is not None and lora_path != "" :
452- lora_paths = lora_path .split ("." )
453- lora_txt_path = f"{ lora_paths [0 ]} _txt.{ lora_paths [1 ]} "
450+ if lora_file_name is not None and lora_file_name != "" :
451+ lora_paths = lora_file_name .split ("." )
452+ lora_txt_file_name = f"{ lora_paths [0 ]} _txt.{ lora_paths [1 ]} "
454453 text_encoder_cls = import_model_class_from_model_name_or_path (config .pretrained_model_name_or_path ,
455454 config .revision )
456455
@@ -461,8 +460,7 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
461460 torch_dtype = torch .float32
462461 )
463462
464- apply_lora (text_encoder , lora_txt_path , config .lora_txt_rank , config .lora_txt_weight , "cpu" , True ,
465- config .use_lora_extended )
463+ apply_lora (config , text_encoder , lora_txt_file_name , "cpu" , True )
466464 text_enc_dict = copy .deepcopy (text_encoder .state_dict ())
467465 del text_encoder
468466 else :
@@ -551,20 +549,15 @@ def load_model(model_path: str, map_location: str):
551549 return loaded
552550
553551
def apply_lora(config: DreamboothConfig, model: nn.Module, lora_file_name: str, device: str, is_tenc: bool):
    """Merge a saved LoRA checkpoint into *model* in place.

    @param config: Dreambooth config supplying the rank/weight settings and model_dir.
    @param model: The unet or text encoder to merge the LoRA weights into.
    @param lora_file_name: Path to a LoRA .pt file, or a bare filename resolved
        under "<config.model_dir>/loras". May be None/"" to skip merging.
    @param device: Map location passed to load_model (e.g. "cpu").
    @param is_tenc: True when *model* is the text encoder, False for the unet.
    @return: The LoRA revision parsed from the filename ("<name>_<rev>.pt"),
        or None if no LoRA file was found/merged.
    """
    lora_rev = None
    if lora_file_name is not None and lora_file_name != "":
        # Bare filenames are resolved relative to the model's "loras" directory.
        if not os.path.exists(lora_file_name):
            lora_file_name = os.path.join(config.model_dir, "loras", lora_file_name)
        if os.path.exists(lora_file_name):
            # Filename convention is "<name>_<revision>.pt".
            lora_rev = lora_file_name.split("_")[-1].replace(".pt", "")
            printi(f"Loading lora from {lora_file_name}", log=True)
            # BUGFIX: the text encoder has its own rank/weight settings
            # (lora_txt_rank / lora_txt_weight); the refactored call passed the
            # unet values (lora_unet_rank / lora_weight) for both models.
            rank = config.lora_txt_rank if is_tenc else config.lora_unet_rank
            weight = config.lora_txt_weight if is_tenc else config.lora_weight
            merge_lora_to_model(model, load_model(lora_file_name, device), is_tenc,
                                config.use_lora_extended, rank, weight)

    return lora_rev
0 commit comments