@@ -1641,6 +1641,337 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
        super().unfuse_lora(components=components)


+class AuraFlowLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`AuraFlowTransformer2DModel`]. Specific to [`AuraFlowPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+    text_encoder_name = TEXT_ENCODER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
1661+ r"""
1662+ Return state dict for lora weights and the network alphas.
1663+
1664+ <Tip warning={true}>
1665+
1666+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
1667+
1668+ This function is experimental and might change in the future.
1669+
1670+ </Tip>
1671+
1672+ Parameters:
1673+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
1674+ Can be either:
1675+
1676+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1677+ the Hub.
1678+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1679+ with [`ModelMixin.save_pretrained`].
1680+ - A [torch state
1681+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
1682+
1683+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1684+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1685+ is not used.
1686+ force_download (`bool`, *optional*, defaults to `False`):
1687+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1688+ cached versions if they exist.
1689+
1690+ proxies (`Dict[str, str]`, *optional*):
1691+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1692+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1693+ local_files_only (`bool`, *optional*, defaults to `False`):
1694+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1695+ won't be downloaded from the Hub.
1696+ token (`str` or *bool*, *optional*):
1697+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1698+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1699+ revision (`str`, *optional*, defaults to `"main"`):
1700+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1701+ allowed by Git.
1702+ subfolder (`str`, *optional*, defaults to `""`):
1703+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1704+
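+        Example (a minimal sketch; the repository id below is an illustrative placeholder, and this assumes the
+        pipeline class, e.g. [`AuraFlowPipeline`], inherits from this mixin):
+
+        ```py
+        from diffusers import AuraFlowPipeline
+
+        # Fetch the LoRA state dict without loading it into any model.
+        state_dict = AuraFlowPipeline.lora_state_dict(
+            "user/some-auraflow-lora", weight_name="pytorch_lora_weights.safetensors"
+        )
+        print(list(state_dict.keys())[:5])  # inspect the LoRA parameter names
+        ```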
1705+ """
1706+ # Load the main state dict first which has the LoRA layers for either of
1707+ # transformer and text encoder or both.
1708+ cache_dir = kwargs .pop ("cache_dir" , None )
1709+ force_download = kwargs .pop ("force_download" , False )
1710+ proxies = kwargs .pop ("proxies" , None )
1711+ local_files_only = kwargs .pop ("local_files_only" , None )
1712+ token = kwargs .pop ("token" , None )
1713+ revision = kwargs .pop ("revision" , None )
1714+ subfolder = kwargs .pop ("subfolder" , None )
1715+ weight_name = kwargs .pop ("weight_name" , None )
1716+ use_safetensors = kwargs .pop ("use_safetensors" , None )
1717+
1718+ allow_pickle = False
1719+ if use_safetensors is None :
1720+ use_safetensors = True
1721+ allow_pickle = True
1722+
1723+ user_agent = {
1724+ "file_type" : "attn_procs_weights" ,
1725+ "framework" : "pytorch" ,
1726+ }
1727+
1728+ state_dict = cls ._fetch_state_dict (
1729+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict ,
1730+ weight_name = weight_name ,
1731+ use_safetensors = use_safetensors ,
1732+ local_files_only = local_files_only ,
1733+ cache_dir = cache_dir ,
1734+ force_download = force_download ,
1735+ proxies = proxies ,
1736+ token = token ,
1737+ revision = revision ,
1738+ subfolder = subfolder ,
1739+ user_agent = user_agent ,
1740+ allow_pickle = allow_pickle ,
1741+ )
1742+
1743+ return state_dict
1744+
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
1748+ """
1749+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
1750+
1751+ All kwargs are forwarded to `self.lora_state_dict`.
1752+
1753+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
1754+ loaded.
1755+
1756+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
1757+ dict is loaded into `self.transformer`.
1758+
1759+ Parameters:
1760+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
1761+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1762+ kwargs (`dict`, *optional*):
1763+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1764+ adapter_name (`str`, *optional*):
1765+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1766+ `default_{i}` where i is the total number of adapters being loaded.
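+
+        Example (a minimal sketch; `fal/AuraFlow` is the AuraFlow base checkpoint, while the LoRA repository id
+        below is an illustrative placeholder):
+
+        ```py
+        import torch
+        from diffusers import AuraFlowPipeline
+
+        pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
+        pipe.load_lora_weights("user/some-auraflow-lora", adapter_name="my_adapter")
+        image = pipe("a watercolor fox in a forest").images[0]
+        ```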
1767+ """
1768+ if not USE_PEFT_BACKEND :
1769+ raise ValueError ("PEFT backend is required for this method." )
1770+
1771+ # if a dict is passed, copy it instead of modifying it inplace
1772+ if isinstance (pretrained_model_name_or_path_or_dict , dict ):
1773+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict .copy ()
1774+
1775+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
1776+ state_dict = self .lora_state_dict (pretrained_model_name_or_path_or_dict , ** kwargs )
1777+
1778+ is_correct_format = all ("lora" in key or "dora_scale" in key for key in state_dict .keys ())
1779+ if not is_correct_format :
1780+ raise ValueError ("Invalid LoRA checkpoint." )
1781+
1782+ self .load_lora_into_transformer (
1783+ state_dict ,
1784+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
1785+ adapter_name = adapter_name ,
1786+ _pipeline = self ,
1787+ )
1788+
1789+
1790+ @classmethod
1791+ def load_lora_into_transformer (cls , state_dict , transformer , adapter_name = None , _pipeline = None ):
1792+ """
1793+ This will load the LoRA layers specified in `state_dict` into `transformer`.
1794+
1795+ Parameters:
1796+ state_dict (`dict`):
1797+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
1798+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
1799+ encoder lora layers.
1800+ transformer (`SD3Transformer2DModel`):
1801+ The Transformer model to load the LoRA layers into.
1802+ adapter_name (`str`, *optional*):
1803+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1804+ `default_{i}` where i is the total number of adapters being loaded.
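+
+        Example (a minimal sketch; this classmethod is normally invoked for you by `load_lora_weights`, and
+        `pipe` is assumed to be a pipeline that inherits from this mixin):
+
+        ```py
+        state_dict = pipe.lora_state_dict("user/some-auraflow-lora")
+        pipe.load_lora_into_transformer(state_dict, transformer=pipe.transformer, _pipeline=pipe)
+        ```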
1805+ """
1806+ from peft import LoraConfig , inject_adapter_in_model , set_peft_model_state_dict
1807+
1808+ keys = list (state_dict .keys ())
1809+
1810+ transformer_keys = [k for k in keys if k .startswith (cls .transformer_name )]
1811+ state_dict = {
1812+ k .replace (f"{ cls .transformer_name } ." , "" ): v for k , v in state_dict .items () if k in transformer_keys
1813+ }
1814+
1815+ if len (state_dict .keys ()) > 0 :
1816+ # check with first key if is not in peft format
1817+ first_key = next (iter (state_dict .keys ()))
1818+ if "lora_A" not in first_key :
1819+ state_dict = convert_unet_state_dict_to_peft (state_dict )
1820+
1821+ if adapter_name in getattr (transformer , "peft_config" , {}):
1822+ raise ValueError (
1823+ f"Adapter name { adapter_name } already in use in the transformer - please select a new adapter name."
1824+ )
1825+
1826+ rank = {}
1827+ for key , val in state_dict .items ():
1828+ if "lora_B" in key :
1829+ rank [key ] = val .shape [1 ]
1830+
1831+ lora_config_kwargs = get_peft_kwargs (rank , network_alpha_dict = None , peft_state_dict = state_dict )
1832+ if "use_dora" in lora_config_kwargs :
1833+ if lora_config_kwargs ["use_dora" ] and is_peft_version ("<" , "0.9.0" ):
1834+ raise ValueError (
1835+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
1836+ )
1837+ else :
1838+ lora_config_kwargs .pop ("use_dora" )
1839+ lora_config = LoraConfig (** lora_config_kwargs )
1840+
1841+ # adapter_name
1842+ if adapter_name is None :
1843+ adapter_name = get_adapter_name (transformer )
1844+
1845+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
1846+ # otherwise loading LoRA weights will lead to an error
1847+ is_model_cpu_offload , is_sequential_cpu_offload = cls ._optionally_disable_offloading (_pipeline )
1848+
1849+ inject_adapter_in_model (lora_config , transformer , adapter_name = adapter_name )
1850+ incompatible_keys = set_peft_model_state_dict (transformer , state_dict , adapter_name )
1851+
1852+ if incompatible_keys is not None :
1853+ # check only for unexpected keys
1854+ unexpected_keys = getattr (incompatible_keys , "unexpected_keys" , None )
1855+ if unexpected_keys :
1856+ logger .warning (
1857+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
1858+ f" { unexpected_keys } . "
1859+ )
1860+
1861+ # Offload back.
1862+ if is_model_cpu_offload :
1863+ _pipeline .enable_model_cpu_offload ()
1864+ elif is_sequential_cpu_offload :
1865+ _pipeline .enable_sequential_cpu_offload ()
1866+ # Unsafe code />
1867+
+    @classmethod
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
1878+ r"""
1879+ Save the LoRA parameters corresponding to the UNet and text encoder.
1880+
1881+ Arguments:
1882+ save_directory (`str` or `os.PathLike`):
1883+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
1884+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1885+ State dict of the LoRA layers corresponding to the `transformer`.
1886+ is_main_process (`bool`, *optional*, defaults to `True`):
1887+ Whether the process calling this is the main process or not. Useful during distributed training and you
1888+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
1889+ process to avoid race conditions.
1890+ save_function (`Callable`):
1891+ The function to use to save the state dictionary. Useful during distributed training when you need to
1892+ replace `torch.save` with another method. Can be configured with the environment variable
1893+ `DIFFUSERS_SAVE_MODE`.
1894+ safe_serialization (`bool`, *optional*, defaults to `True`):
1895+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
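+
+        Example (a minimal sketch; assumes `transformer` holds a PEFT LoRA adapter obtained during fine-tuning.
+        Note that training scripts typically also convert the extracted state dict to the Diffusers key format
+        before saving):
+
+        ```py
+        from peft.utils import get_peft_model_state_dict
+
+        AuraFlowLoraLoaderMixin.save_lora_weights(
+            save_directory="./auraflow-lora",
+            transformer_lora_layers=get_peft_model_state_dict(transformer),
+        )
+        ```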
1896+ """
1897+ state_dict = {}
1898+
1899+ if not (transformer_lora_layers ):
1900+ raise ValueError (
1901+ "You must pass `transformer_lora_layers`."
1902+ )
1903+
1904+ state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
1905+
1906+ # Save the model
1907+ cls .write_lora_layers (
1908+ state_dict = state_dict ,
1909+ save_directory = save_directory ,
1910+ is_main_process = is_main_process ,
1911+ weight_name = weight_name ,
1912+ save_function = save_function ,
1913+ safe_serialization = safe_serialization ,
1914+ )
1915+
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
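+
+        Example (a minimal sketch, continuing from a pipeline on which `fuse_lora()` has been called):
+
+        ```py
+        pipeline.fuse_lora(lora_scale=0.7)
+        # ... run inference with the fused weights ...
+        pipeline.unfuse_lora()  # restore the original, unfused transformer parameters
+        ```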
1972+ """
1973+ super ().unfuse_lora (components = components )
1974+
class FluxLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`FluxTransformer2DModel`],
@@ -1649,6 +1980,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
    Specific to [`StableDiffusion3Pipeline`].
    """

+# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
+# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
+class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    _lora_loadable_modules = ["transformer", "text_encoder"]
    transformer_name = TRANSFORMER_NAME
    text_encoder_name = TEXT_ENCODER_NAME