
Commit 49186b8

Warlord-Khameerabbasi authored and committed

Add AuraFlowLoraLoaderMixin

1 parent cef0e36 · commit 49186b8

File tree

4 files changed: +360 −4 lines changed

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,7 @@ def text_encoder_attn_modules(text_encoder):
         "AmusedLoraLoaderMixin",
         "StableDiffusionLoraLoaderMixin",
         "SD3LoraLoaderMixin",
+        "AuraFlowLoraLoaderMixin",
         "StableDiffusionXLLoraLoaderMixin",
         "LoraLoaderMixin",
         "FluxLoraLoaderMixin",
@@ -91,6 +92,7 @@ def text_encoder_attn_modules(text_encoder):
         LoraLoaderMixin,
         Mochi1LoraLoaderMixin,
         SD3LoraLoaderMixin,
+        AuraFlowLoraLoaderMixin,
         StableDiffusionLoraLoaderMixin,
         StableDiffusionXLLoraLoaderMixin,
     )
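With the export registered in src/diffusers/loaders/__init__.py, the new mixin becomes importable from the loaders namespace. A minimal smoke test, assuming a diffusers build that contains this commit and that one of the other changed files wires the mixin into AuraFlowPipeline:

```py
from diffusers import AuraFlowPipeline
from diffusers.loaders import AuraFlowLoraLoaderMixin

# Holds only if the pipeline class in this commit inherits the new mixin;
# otherwise just the import itself is guaranteed by this file's change.
print(issubclass(AuraFlowPipeline, AuraFlowLoraLoaderMixin))
```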

src/diffusers/loaders/lora_pipeline.py

Lines changed: 334 additions & 0 deletions

@@ -1641,6 +1641,337 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
         super().unfuse_lora(components=components)


+class AuraFlowLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`AuraFlowTransformer2DModel`]. Specific to [`AuraFlowPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+    text_encoder_name = TEXT_ENCODER_NAME
+
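`_lora_loadable_modules` tells the generic machinery in `LoraBaseMixin` which pipeline components can receive LoRA layers; for AuraFlow that is only the transformer. A simplified sketch of the idea (illustration only, not the actual lora_base.py code):

```py
# Base-class helpers iterate _lora_loadable_modules to locate the
# components they should fuse, unfuse, or scale.
for module_name in ["transformer"]:  # cls._lora_loadable_modules
    module = getattr(pipeline, module_name, None)
    if module is not None:
        ...  # apply the LoRA operation to this component
```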
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return the state dict for the LoRA weights.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+        """
+        # Load the main state dict first, which has the LoRA layers for either the
+        # transformer or the text encoder, or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = cls._fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        return state_dict
+
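The method above only fetches and returns the checkpoint's state dict, which makes it handy for inspecting a LoRA file before loading it. A hedged sketch (the repo id and weight file name are placeholders):

```py
from diffusers.loaders import AuraFlowLoraLoaderMixin

# Placeholder repo id and weight name; this is a classmethod, so no pipeline is needed.
state_dict = AuraFlowLoraLoaderMixin.lora_state_dict(
    "your-username/auraflow-lora", weight_name="pytorch_lora_weights.safetensors"
)
# Keys are expected to look like "transformer.<module>.lora_A.weight".
print(list(state_dict)[:5])
```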
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
+
+        All kwargs are forwarded to `self.lora_state_dict`.
+
+        See [`~loaders.AuraFlowLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+
+        See [`~loaders.AuraFlowLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is
+        loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.AuraFlowLoraLoaderMixin.lora_state_dict`].
+            kwargs (`dict`, *optional*):
+                See [`~loaders.AuraFlowLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        # If a dict is passed, copy it instead of modifying it in place.
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+        )
+
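End to end, the intended call pattern would look roughly like this (hedged: the LoRA repo id is a placeholder, and the example assumes AuraFlowPipeline inherits this mixin via the accompanying pipeline change):

```py
import torch
from diffusers import AuraFlowPipeline

# "fal/AuraFlow" is the reference AuraFlow checkpoint; the LoRA id is a placeholder.
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("your-username/auraflow-lora", adapter_name="my_style")
image = pipe("a watercolor fox in a misty forest").images[0]
```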
+    @classmethod
+    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+                them from text encoder lora layers.
+            transformer (`AuraFlowTransformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+        """
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+        keys = list(state_dict.keys())
+
+        transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
+        state_dict = {
+            k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+        }
+
+        if len(state_dict.keys()) > 0:
+            # Check with the first key whether the state dict is in peft format.
+            first_key = next(iter(state_dict.keys()))
+            if "lora_A" not in first_key:
+                state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+            if adapter_name in getattr(transformer, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+                )
+
+            rank = {}
+            for key, val in state_dict.items():
+                if "lora_B" in key:
+                    rank[key] = val.shape[1]
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+                else:
+                    lora_config_kwargs.pop("use_dora")
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(transformer)
+
+            # In case the pipeline has already been offloaded to CPU, temporarily remove the hooks;
+            # otherwise loading the LoRA weights will lead to an error.
+            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
+            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+
+            if incompatible_keys is not None:
+                # Check only for unexpected keys.
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    logger.warning(
+                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                        f" {unexpected_keys}. "
+                    )
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
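The rank detection in the loop above relies on the LoRA shape convention: for a pair `lora_A` of shape `(r, in_features)` and `lora_B` of shape `(out_features, r)`, the rank `r` is the second dimension of `lora_B`. A tiny illustration with made-up shapes:

```py
import torch

# Hypothetical rank-8 LoRA pair for a 64 -> 128 linear layer.
lora_A = torch.randn(8, 64)   # (r, in_features)
lora_B = torch.randn(128, 8)  # (out_features, r)

# Mirrors the loader: rank is read off lora_B.shape[1].
assert lora_B.shape[1] == 8
```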
+    @classmethod
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
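A hedged sketch of the save path, assuming LoRA layers were attached to the transformer (for example during training) and using peft's get_peft_model_state_dict to extract only the adapter weights:

```py
from peft.utils import get_peft_model_state_dict

from diffusers.loaders import AuraFlowLoraLoaderMixin

# The output directory is a placeholder; `pipe` carries a LoRA-adapted transformer.
AuraFlowLoraLoaderMixin.save_lora_weights(
    save_directory="./auraflow-lora-out",
    transformer_lora_layers=get_peft_model_state_dict(pipe.transformer),
)
```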
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check the fused weights for NaN values before fusing and, if values are NaN, to skip
+                fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+        """
+        super().unfuse_lora(components=components)
+
 class FluxLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into [`FluxTransformer2DModel`],
@@ -1649,6 +1980,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
     Specific to [`StableDiffusion3Pipeline`].
     """

+# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
+# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
+class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     _lora_loadable_modules = ["transformer", "text_encoder"]
     transformer_name = TRANSFORMER_NAME
     text_encoder_name = TEXT_ENCODER_NAME
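For completeness, fusing composes with the loader as in the docstring example above; a hedged AuraFlow variant of the same flow, continuing the placeholder ids from the earlier sketch:

```py
# Merge the LoRA deltas into the transformer weights for faster inference,
# then restore the original weights.
pipe.fuse_lora(lora_scale=0.7)
image = pipe("a pixel-art castle").images[0]
pipe.unfuse_lora()
```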
