@@ -424,6 +424,17 @@ def _load_lora_into_text_encoder(
 
 
 def _func_optionally_disable_offloading(_pipeline):
+    """
+    Optionally removes offloading in case the pipeline has already been sequentially offloaded to CPU.
+
+    Args:
+        _pipeline (`DiffusionPipeline`):
+            The pipeline to disable offloading for.
+
+    Returns:
+        tuple:
+            A tuple indicating whether `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+    """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
 
@@ -442,7 +453,8 @@ def _func_optionally_disable_offloading(_pipeline):
             logger.info(
                 "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
             )
-            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+            if is_sequential_cpu_offload or is_model_cpu_offload:
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
     return (is_model_cpu_offload, is_sequential_cpu_offload)
 
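For orientation: a minimal sketch of how the returned flags are typically consumed around a LoRA load. The `pipeline` object and the two `enable_*_offload()` restore calls are assumptions based on the standard pipeline API, not part of this diff.

```py
# Hypothetical caller sketch: detach accelerate hooks before mutating modules,
# then restore whichever offloading scheme was previously active.
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline=pipeline)

# ... load the LoRA parameters into the now-unhooked modules ...

if is_model_cpu_offload:
    pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
    pipeline.enable_sequential_cpu_offload()
```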
@@ -453,6 +465,24 @@ class LoraBaseMixin:
     _lora_loadable_modules = []
     _merged_adapters = set()
 
+    @property
+    def lora_scale(self) -> float:
+        """
+        Returns the LoRA scale, which can be set at run time by the pipeline. If `_lora_scale` has not been set,
+        returns 1.0.
+        """
+        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+    @property
+    def num_fused_loras(self):
+        """Returns the number of LoRAs that have been fused."""
+        return len(self._merged_adapters)
+
+    @property
+    def fused_loras(self):
+        """Returns the names of the LoRAs that have been fused."""
+        return self._merged_adapters
+
     def load_lora_weights(self, **kwargs):
         raise NotImplementedError("`load_lora_weights()` is not implemented.")
 
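A quick usage sketch for the three properties added above; the `fuse_lora()` call follows the public diffusers API, and the adapter name is a placeholder.

```py
# Hypothetical sketch: inspect the LoRA bookkeeping properties after fusing.
# Assumes a LoRA was loaded earlier under adapter_name="cinematic".
pipeline.fuse_lora(adapter_names=["cinematic"])

print(pipeline.num_fused_loras)  # 1
print(pipeline.fused_loras)      # {"cinematic"} -- the set backing `_merged_adapters`
print(pipeline.lora_scale)       # 1.0 unless `_lora_scale` was set at run time
```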
@@ -464,33 +494,6 @@ def save_lora_weights(cls, **kwargs):
     def lora_state_dict(cls, **kwargs):
         raise NotImplementedError("`lora_state_dict()` is not implemented.")
 
-    @classmethod
-    def _optionally_disable_offloading(cls, _pipeline):
-        """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
-
-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
-        return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
-    @classmethod
-    def _fetch_state_dict(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
-        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
-        return _fetch_state_dict(*args, **kwargs)
-
-    @classmethod
-    def _best_guess_weight_name(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
-        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
-        return _best_guess_weight_name(*args, **kwargs)
-
     def unload_lora_weights(self):
         """
         Unloads the LoRA parameters.
@@ -661,19 +664,37 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
                             self._merged_adapters = self._merged_adapters - {adapter}
                         module.unmerge()
 
-    @property
-    def num_fused_loras(self):
-        return len(self._merged_adapters)
-
-    @property
-    def fused_loras(self):
-        return self._merged_adapters
-
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
+        """
+        Set the currently active adapters for use in the pipeline.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            adapter_weights (`Union[List[float], float]`, *optional*):
+                The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
+                adapters.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        ```
+        """
         if isinstance(adapter_weights, dict):
             components_passed = set(adapter_weights.keys())
             lora_components = set(self._lora_loadable_modules)
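The dict branch above validates per-component weights against `_lora_loadable_modules`. A minimal sketch of that form, assuming an SDXL-style pipeline whose loadable modules include `unet` and `text_encoder`:

```py
# Hypothetical sketch: a top-level dict scopes the adapter weights per component;
# its keys are checked against `_lora_loadable_modules` by the code above.
pipeline.set_adapters("pixel", adapter_weights={"unet": 0.8, "text_encoder": 0.4})
```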
@@ -743,6 +764,24 @@ def set_adapters(
                     set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
 
     def disable_lora(self):
+        """
+        Disables the active LoRA layers of the pipeline.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.disable_lora()
+        ```
+        """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
@@ -755,6 +794,24 @@ def disable_lora(self):
                     disable_lora_for_text_encoder(model)
 
     def enable_lora(self):
+        """
+        Enables the active LoRA layers of the pipeline.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.enable_lora()
+        ```
+        """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
@@ -768,10 +825,26 @@ def enable_lora(self):
 
     def delete_adapters(self, adapter_names: Union[List[str], str]):
         """
+        Delete an adapter's LoRA layers from the pipeline.
+
         Args:
-            Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
             adapter_names (`Union[List[str], str]`):
-                The names of the adapter to delete. Can be a single string or a list of strings
+                The names of the adapters to delete.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.delete_adapters("cinematic")
+        ```
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -872,6 +945,24 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
                             adapter_name
                         ].to(device)
 
+    def enable_lora_hotswap(self, **kwargs) -> None:
+        """
+        Enable hotswapping of LoRA adapters without triggering recompilation when the model is compiled, or when
+        the ranks of the loaded adapters differ.
+
+        Args:
+            target_rank (`int`):
+                The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle the case when the model is already compiled. The options are:
+                - "error" (default): raise an error
+                - "warn": issue a warning
+                - "ignore": do nothing
+        """
+        for key, component in self.components.items():
+            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+                component.enable_lora_hotswap(**kwargs)
+
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
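A sketch of the hotswap workflow that `enable_lora_hotswap` prepares; the repo ids are placeholders, and the `hotswap=True` argument to `load_lora_weights` is assumed from the public diffusers API rather than shown in this diff.

```py
import torch

# Hypothetical sketch: call before loading the first adapter and compiling.
pipeline.enable_lora_hotswap(target_rank=64)  # highest rank among the adapters to load
pipeline.load_lora_weights("user/lora-a", adapter_name="default")
pipeline.unet = torch.compile(pipeline.unet)  # compile once

# Swapping a new LoRA into the same adapter slot reuses the compiled graph.
pipeline.load_lora_weights("user/lora-b", adapter_name="default", hotswap=True)
```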
@@ -887,6 +978,7 @@ def write_lora_layers(
         safe_serialization: bool,
         lora_adapter_metadata: Optional[dict] = None,
     ):
+        """Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
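For context, a sketch of how a concrete `save_lora_weights()` implementation typically feeds `write_lora_layers` with `pack_weights` output. Only `save_directory`, `safe_serialization`, and `lora_adapter_metadata` are visible in this diff; the remaining parameter names (`state_dict`, `weight_name`, `save_function`) and the layer variables are assumptions.

```py
# Hypothetical sketch: collect prefixed LoRA weights, then write them to disk.
state_dict = {}
state_dict.update(LoraBaseMixin.pack_weights(unet_lora_layers, prefix="unet"))
state_dict.update(LoraBaseMixin.pack_weights(text_encoder_lora_layers, prefix="text_encoder"))

LoraBaseMixin.write_lora_layers(
    state_dict=state_dict,
    save_directory="./my-lora",
    weight_name="pytorch_lora_weights.safetensors",  # assumed parameter
    save_function=None,  # assumed: falls back to a default save function
    safe_serialization=True,
)
```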
@@ -927,28 +1019,18 @@ def save_function(weights, filename):
             save_function(state_dict, save_path)
             logger.info(f"Model weights saved in {save_path}")
 
-    @property
-    def lora_scale(self) -> float:
-        # property function that returns the lora scale which can be set at run time by the pipeline.
-        # if _lora_scale has not been set, return 1
-        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
-
-    def enable_lora_hotswap(self, **kwargs) -> None:
-        """Enables the possibility to hotswap LoRA adapters.
+    @classmethod
+    def _optionally_disable_offloading(cls, _pipeline):
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
-        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
-        the loaded adapters differ.
+    @classmethod
+    def _fetch_state_dict(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
+        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
+        return _fetch_state_dict(*args, **kwargs)
 
-        Args:
-            target_rank (`int`):
-                The highest rank among all the adapters that will be loaded.
-            check_compiled (`str`, *optional*, defaults to `"error"`):
-                How to handle the case when the model is already compiled, which should generally be avoided. The
-                options are:
-                - "error" (default): raise an error
-                - "warn": issue a warning
-                - "ignore": do nothing
-        """
-        for key, component in self.components.items():
-            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
-                component.enable_lora_hotswap(**kwargs)
+    @classmethod
+    def _best_guess_weight_name(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
+        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
+        return _best_guess_weight_name(*args, **kwargs)