
Commit edea577

patrickvonplaten and multimodalart committed
[Lora] fix lora fuse unfuse (#5003)
* fix lora fuse unfuse
* add same changes to loaders.py
* add test

Co-authored-by: multimodalart <[email protected]>
1 parent c37c840 commit edea577

File tree: 3 files changed (+42 −4 lines)

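For context: fusing a LoRA bakes the low-rank update directly into the base weight so inference skips the extra adapter matmul, and unfusing subtracts that same delta to restore the original weight. Below is a minimal sketch of that round trip, assuming the standard W' = W + scale · (w_up @ w_down) formulation; the variable names are illustrative, not the diffusers internals.

    import torch

    # Toy LoRA fuse/unfuse round trip (illustrative, not the diffusers code paths changed below).
    weight = torch.randn(8, 8)      # base layer weight
    w_up = torch.randn(8, 4)        # LoRA up-projection factor
    w_down = torch.randn(4, 8)      # LoRA down-projection factor
    lora_scale = 1.0

    fused = weight + lora_scale * (w_up @ w_down)    # fuse: add the low-rank delta into the weight
    restored = fused - lora_scale * (w_up @ w_down)  # unfuse: subtract the same delta again

    assert torch.allclose(restored, weight, atol=1e-5)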

src/diffusers/loaders.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def _fuse_lora(self, lora_scale=1.0):
         self.lora_scale = lora_scale
 
     def _unfuse_lora(self):
-        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
             return
 
         fused_weight = self.regular_linear_layer.weight.data
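The guard change above matters because, after a successful unfuse, the cached w_up / w_down attributes are typically reset to None rather than deleted, so hasattr still reports them as present and a later unfuse call would try to operate on None weights. A minimal sketch of the failure mode and the new check, assuming the unfuse path clears the factors by assigning None (class and method names here are illustrative, not the diffusers classes):

    class PatchedLayer:
        """Toy stand-in for a LoRA-patched layer (illustrative only)."""

        def fuse(self, w_up, w_down):
            self.w_up, self.w_down = w_up, w_down  # cache the LoRA factors while fused

        def unfuse(self):
            # Old guard: `if not (hasattr(self, "w_up") and hasattr(self, "w_down")): return`
            # keeps going once the attributes exist but hold None.
            if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
                return
            # ... subtract the LoRA delta from the fused weight here ...
            self.w_up = None    # attributes still exist after unfusing,
            self.w_down = None  # but now hold None

    layer = PatchedLayer()
    layer.unfuse()            # no-op: nothing fused yet
    layer.fuse([1.0], [2.0])
    layer.unfuse()            # real unfuse; w_up / w_down are reset to None
    layer.unfuse()            # hasattr() would wrongly proceed here; the getattr(...) is not None check returns early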

src/diffusers/models/lora.py

Lines changed: 2 additions & 2 deletions
@@ -139,7 +139,7 @@ def _fuse_lora(self, lora_scale=1.0):
         self._lora_scale = lora_scale
 
     def _unfuse_lora(self):
-        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
             return
 
         fused_weight = self.weight.data
@@ -204,7 +204,7 @@ def _fuse_lora(self, lora_scale=1.0):
         self._lora_scale = lora_scale
 
     def _unfuse_lora(self):
-        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
             return
 
         fused_weight = self.weight.data

tests/models/test_lora_layers.py

Lines changed: 39 additions & 1 deletion
@@ -43,7 +43,7 @@
     LoRAAttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import floats_tensor, nightly, require_torch_gpu, slow, torch_device
 
 
 def create_unet_lora_layers(unet: nn.Module):
@@ -1464,3 +1464,41 @@ def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self):
         expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
 
         self.assertTrue(np.allclose(images, expected, atol=1e-3))
+
+    @nightly
+    def test_sequential_fuse_unfuse(self):
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+
+        # 1. round
+        pipe.load_lora_weights("Pclanglais/TintinIA")
+        pipe.fuse_lora()
+
+        generator = torch.Generator().manual_seed(0)
+        images = pipe(
+            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+        ).images
+        image_slice = images[0, -3:, -3:, -1].flatten()
+
+        pipe.unfuse_lora()
+
+        # 2. round
+        pipe.load_lora_weights("ProomptEngineer/pe-balloon-diffusion-style")
+        pipe.fuse_lora()
+        pipe.unfuse_lora()
+
+        # 3. round
+        pipe.load_lora_weights("ostris/crayon_style_lora_sdxl")
+        pipe.fuse_lora()
+        pipe.unfuse_lora()
+
+        # 4. back to 1st round
+        pipe.load_lora_weights("Pclanglais/TintinIA")
+        pipe.fuse_lora()
+
+        generator = torch.Generator().manual_seed(0)
+        images_2 = pipe(
+            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+        ).images
+        image_slice_2 = images_2[0, -3:, -3:, -1].flatten()
+
+        self.assertTrue(np.allclose(image_slice, image_slice_2, atol=1e-3))
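The new test carries the @nightly marker, so it only runs when nightly tests are enabled. One possible way to run just this test locally, assuming the suite gates nightly tests on a RUN_NIGHTLY environment flag (the exact invocation may differ):

    import os
    import pytest

    os.environ["RUN_NIGHTLY"] = "1"  # assumed flag that enables @nightly-marked tests
    pytest.main(["tests/models/test_lora_layers.py", "-k", "test_sequential_fuse_unfuse"])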
