
Commit 668e34c

[LoRA SD3] add support for lora fusion in sd3 (#8616)
* add support for lora fusion in sd3
* add test to ensure fused lora and effective lora produce same outputs
1 parent 25d7bb3 · commit 668e34c

File tree: 3 files changed, +200 −0 lines changed

src/diffusers/loaders/lora.py

Lines changed: 75 additions & 0 deletions
@@ -1728,3 +1728,78 @@ def _optionally_disable_offloading(cls, _pipeline):
````python
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

        return (is_model_cpu_offload, is_sequential_cpu_offload)

    def fuse_lora(
        self,
        fuse_transformer: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            fuse_transformer (`bool`, defaults to `True`): Whether to fuse the transformer LoRA parameters.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check the fused weights for NaN values before fusing and skip fusion if NaN values are present.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "nerijs/pixel-art-medium-128-v0.1",
            weight_name="pixel-art-medium-128-v0.1.safetensors",
            adapter_name="pixel",
        )
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        if fuse_transformer:
            self.num_fused_loras += 1

        if fuse_transformer:
            transformer = (
                getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
            )
            transformer.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)

    def unfuse_lora(self, unfuse_transformer: bool = True):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
        """
        from peft.tuners.tuners_utils import BaseTunerLayer

        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
        if unfuse_transformer:
            for module in transformer.modules():
                if isinstance(module, BaseTunerLayer):
                    module.unmerge()

        self.num_fused_loras -= 1
````
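For context, here is a hedged sketch of the round trip this pipeline-level API enables. The checkpoint and adapter names mirror the docstring example above; the prompt and exact call sequence are illustrative rather than part of the commit.

```python
import torch
from diffusers import DiffusionPipeline

# Sketch of the fuse/unfuse round trip; checkpoint, adapter name, and prompt
# mirror the docstring example above and are not prescriptive.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights(
    "nerijs/pixel-art-medium-128-v0.1",
    weight_name="pixel-art-medium-128-v0.1.safetensors",
    adapter_name="pixel",
)

# Fuse only the "pixel" adapter, checking for NaNs while merging, then generate.
pipe.fuse_lora(lora_scale=0.7, safe_fusing=True, adapter_names=["pixel"])
image = pipe("pixel art of an astronaut riding a horse").images[0]

# Undo the merge so the base weights are restored while the adapter layers remain.
pipe.unfuse_lora()
```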

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 43 additions & 0 deletions
@@ -13,6 +13,8 @@
```python
# limitations under the License.


import inspect
from functools import partial
from typing import Any, Dict, List, Optional, Union

import torch
```
@@ -239,6 +241,47 @@ def _set_gradient_checkpointing(self, module, value=False):
```python
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `fuse_lora()`.")

        self.lora_scale = lora_scale
        self._safe_fusing = safe_fusing
        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))

    def _fuse_lora_apply(self, module, adapter_names=None):
        from peft.tuners.tuners_utils import BaseTunerLayer

        merge_kwargs = {"safe_merge": self._safe_fusing}

        if isinstance(module, BaseTunerLayer):
            if self.lora_scale != 1.0:
                module.scale_layer(self.lora_scale)

            # For BC with previous PEFT versions, we need to check the signature
            # of the `merge` method to see if it supports the `adapter_names` argument.
            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
            if "adapter_names" in supported_merge_kwargs:
                merge_kwargs["adapter_names"] = adapter_names
            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
                raise ValueError(
                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
                    " to the latest version of PEFT. `pip install -U peft`"
                )

            module.merge(**merge_kwargs)

    def unfuse_lora(self):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
        self.apply(self._unfuse_lora_apply)

    def _unfuse_lora_apply(self, module):
        from peft.tuners.tuners_utils import BaseTunerLayer

        if isinstance(module, BaseTunerLayer):
            module.unmerge()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
```
tests/lora/test_lora_layers_sd3.py

Lines changed: 82 additions & 0 deletions
@@ -205,3 +205,85 @@ def test_simple_inference_with_transformer_lora_and_scale(self):
```python
            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
            "Lora + 0 scale should lead to same result as no LoRA",
        )

    def test_simple_inference_with_transformer_fused(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )

    def test_simple_inference_with_transformer_fused_with_no_fusion(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        inputs = self.get_dummy_inputs(torch_device)
        output_lora = pipe(**inputs).images

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )
        self.assertTrue(
            np.allclose(output_fused, output_lora, atol=1e-3, rtol=1e-3),
            "Fused lora output should match the output of the unfused but still active LoRA.",
        )

    def test_simple_inference_with_transformer_fuse_unfuse(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )

        pipe.unfuse_lora()
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        inputs = self.get_dummy_inputs(torch_device)
        output_unfused_lora = pipe(**inputs).images
        self.assertTrue(
            np.allclose(output_fused, output_unfused_lora, atol=1e-3, rtol=1e-3),
            "Unfusing should give the same output as the fused LoRA.",
        )
```
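To exercise just the new tests, something like the following should work; the file path comes from the diff header above, while the `pytest` invocation and the `-k` filter are assumptions about how you might select them.

```python
# Sketch: run only the new SD3 LoRA fusion tests (path taken from the diff above).
# The -k expression matching the new test names is illustrative.
import pytest

exit_code = pytest.main(["tests/lora/test_lora_layers_sd3.py", "-k", "fused or fuse_unfuse", "-q"])
print("pytest exit code:", exit_code)
```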
