@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 import re
-import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from io import BytesIO
@@ -307,6 +306,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
+
+        _pipeline = kwargs.pop("_pipeline", None)
+
         is_network_alphas_none = network_alphas is None

         allow_pickle = False
@@ -460,6 +462,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
                 else:
                     lora.load_state_dict(value_dict)
+
         elif is_custom_diffusion:
             attn_processors = {}
             custom_diffusion_grouped_dict = defaultdict(dict)
@@ -489,19 +492,44 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                         cross_attention_dim=cross_attention_dim,
                     )
                 attn_processors[key].load_state_dict(value_dict)
-
-            self.set_attn_processor(attn_processors)
         else:
             raise ValueError(
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
             )

+        # <Unsafe code
+        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype.
+        # Now we remove any existing hooks so that the LoRA layers can be loaded; the hooks are re-applied afterwards.
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        if _pipeline is not None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module):
+                    if hasattr(component, "_hf_hook"):
+                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        logger.info(
+                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                        )
+                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+        # only custom diffusion needs to set attn processors
+        if is_custom_diffusion:
+            self.set_attn_processor(attn_processors)
+
         # set lora layers
         for target_module, lora_layer in lora_layers_list:
             target_module.set_lora_layer(lora_layer)

         self.to(dtype=self.dtype, device=self.device)

+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />
+
     def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
         is_new_lora_format = all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
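
The hunk above is the core of the change: before attaching the freshly built LoRA layers, `load_attn_procs` inspects the parent pipeline (handed in through the private `_pipeline` kwarg) for accelerate offload hooks, strips them, loads the weights, and then switches the detected offloading mode back on. A minimal standalone sketch of the detect-and-strip step, using only the accelerate hook utilities already referenced in the diff (`CpuOffload`, `AlignDevicesHook`, `remove_hook_from_module`); the helper name `_detach_offload_hooks` is made up for illustration:

    import torch.nn as nn
    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

    def _detach_offload_hooks(pipeline):
        # Report which offloading mode was active so the caller can restore it later.
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        if pipeline is not None:
            for component in pipeline.components.values():
                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                    is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
                    # Sequential offload hooks every submodule, so removal must recurse.
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
        return is_model_cpu_offload, is_sequential_cpu_offload

The `recurse=is_sequential_cpu_offload` detail matters: `enable_sequential_cpu_offload()` installs an `AlignDevicesHook` on every submodule, while `enable_model_cpu_offload()` attaches a single `CpuOffload` hook per top-level component, which is exactly what the two `isinstance` checks distinguish.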
@@ -1060,41 +1088,31 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
-        # Remove any existing hooks.
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        recurive = False
-        for _, component in self.components.items():
-            if isinstance(component, nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    recurive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recurive)
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")

         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )

-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
-
     @classmethod
     def lora_state_dict(
         cls,
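
For context, the scenario this refactor targets is loading LoRA weights into a pipeline that is already offloaded. A hedged usage sketch (the model id is the usual example checkpoint and `path/to/lora.safetensors` is a placeholder, not a real file):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    pipe.enable_model_cpu_offload()  # installs accelerate CpuOffload hooks

    # Because `load_lora_weights()` now forwards `_pipeline=self`, the loaders can
    # lift those hooks, attach the LoRA layers, and re-enable offloading afterwards.
    pipe.load_lora_weights("path/to/lora.safetensors")

Previously the hook bookkeeping lived in `load_lora_weights()` itself; moving it into `load_attn_procs` and `load_lora_into_text_encoder` leaves the pipeline entry point with validation and delegation only.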
@@ -1391,7 +1409,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="
         return new_state_dict

     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -1433,13 +1451,22 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage
             # Otherwise, we're dealing with the old format. This means the `state_dict` should only
             # contain the module names of the `unet` as its keys WITHOUT any prefix.
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            warnings.warn(warn_message)
+            logger.warn(warn_message)

-        unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
+        unet.load_attn_procs(
+            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
+        )

     @classmethod
     def load_lora_into_text_encoder(
-        cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        low_cpu_mem_usage=None,
+        _pipeline=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
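
The deprecation message above already contains the conversion recipe for old-format checkpoints; spelled out as a runnable sketch (the single key in `old_state_dict` is a toy value, purely illustrative):

    import torch

    # Old format: keys are bare UNet module names, without the "unet." prefix.
    old_state_dict = {"mid_block.attentions.0.to_k_lora.down.weight": torch.zeros(4, 320)}

    # New format expected by the loaders: prefix every key with "unet.".
    new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}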
@@ -1549,11 +1576,15 @@ def load_lora_into_text_encoder(
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )

-            # set correct dtype & device
-            text_encoder_lora_state_dict = {
-                k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                for k, v in text_encoder_lora_state_dict.items()
-            }
+            is_pipeline_offloaded = _pipeline is not None and any(
+                isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
+            )
+            if is_pipeline_offloaded and not low_cpu_mem_usage:
+                low_cpu_mem_usage = True
+                logger.info(
+                    f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
+                )
+
             if low_cpu_mem_usage:
                 device = next(iter(text_encoder_lora_state_dict.values())).device
                 dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
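
The `is_pipeline_offloaded` probe added in this hunk is a cheap way of telling whether accelerate has instrumented any pipeline component; as a standalone sketch (the function name is illustrative, `_hf_hook` and `components` come from the diff itself):

    import torch

    def pipeline_is_offloaded(pipeline) -> bool:
        # Both model CPU offload and sequential CPU offload leave an accelerate hook
        # behind, reachable through the private `_hf_hook` attribute of each module.
        if pipeline is None:
            return False
        return any(
            isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook")
            for component in pipeline.components.values()
        )

When that probe is positive, the loader forces `low_cpu_mem_usage`, matching the log message emitted just above.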
@@ -1569,8 +1600,33 @@ def load_lora_into_text_encoder(
                     f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
                 )

+        # <Unsafe code
+        # We can be sure that the following works as all we do is change the dtype and device of the text encoder.
+        # Now we remove any existing hooks so that the LoRA layers can be loaded; the hooks are re-applied afterwards.
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        if _pipeline is not None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, torch.nn.Module):
+                    if hasattr(component, "_hf_hook"):
+                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                        is_sequential_cpu_offload = isinstance(
+                            getattr(component, "_hf_hook"), AlignDevicesHook
+                        )
+                        logger.info(
+                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                        )
+                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
         text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />
+
     @property
     def lora_scale(self) -> float:
         # property function that returns the lora scale which can be set at run time by the pipeline.
@@ -2639,31 +2695,17 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.

-        # Remove any existing hooks.
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-        else:
-            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
-
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        for _, component in self.components.items():
-            if isinstance(component, torch.nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
             **kwargs,
         )
-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")

+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
@@ -2672,6 +2714,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
             )

         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -2682,14 +2725,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 text_encoder=self.text_encoder_2,
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
             )

-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
-
     @classmethod
     def save_lora_weights(
         self,
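
The SDXL mixin receives the same treatment: checkpoint validation up front, `_pipeline=self` forwarded for the UNet and both text encoders, and no hook juggling left in `load_lora_weights()` itself. A hedged end-to-end sketch with sequential offloading (the model id is the public SDXL base checkpoint; the LoRA path is a placeholder):

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    pipe.enable_sequential_cpu_offload()  # AlignDevicesHook on every submodule

    # LoRA layers are loaded into the UNet, text_encoder and text_encoder_2, and
    # sequential offloading is re-enabled once loading finishes.
    pipe.load_lora_weights("path/to/sdxl_lora.safetensors")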