Skip to content

Commit 0371cee

Browse files
authored
Merge branch 'main' into lora/fix-wan
2 parents 763fd3b + 2527917 commit 0371cee

File tree

2 files changed

+74
-2
lines changed

2 files changed

+74
-2
lines changed

src/diffusers/loaders/lora_base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,27 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
934934
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
935935
you want to load multiple adapters and free some GPU memory.
936936
937+
After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
938+
can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
939+
GPU before using those LoRA adapters for inference.
940+
941+
```python
942+
>>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
943+
>>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
944+
>>> pipe.set_adapters("adapter-1")
945+
>>> image_1 = pipe(**kwargs)
946+
>>> # switch to adapter-2, offload adapter-1
947+
>>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cpu")
948+
>>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
949+
>>> pipe.set_adapters("adapter-2")
950+
>>> image_2 = pipe(**kwargs)
951+
>>> # switch back to adapter-1, offload adapter-2
952+
>>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cpu")
953+
>>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
954+
>>> pipe.set_adapters("adapter-1")
955+
>>> ...
956+
```
957+
937958
Args:
938959
adapter_names (`List[str]`):
939960
List of adapters to send device to.
@@ -949,6 +970,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
949970
for module in model.modules():
950971
if isinstance(module, BaseTunerLayer):
951972
for adapter_name in adapter_names:
973+
if adapter_name not in module.lora_A:
974+
# it is sufficient to check lora_A
975+
continue
976+
952977
module.lora_A[adapter_name].to(device)
953978
module.lora_B[adapter_name].to(device)
954979
# this is a param, not a module, so device placement is not in-place -> re-assign

tests/lora/test_lora_layers_sd.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_integration_move_lora_cpu(self):
120120

121121
self.assertTrue(
122122
check_if_lora_correctly_set(pipe.unet),
123-
"Lora not correctly set in text encoder",
123+
"Lora not correctly set in unet",
124124
)
125125

126126
# We will offload the first adapter in CPU and check if the offloading
@@ -187,7 +187,7 @@ def test_integration_move_lora_dora_cpu(self):
187187

188188
self.assertTrue(
189189
check_if_lora_correctly_set(pipe.unet),
190-
"Lora not correctly set in text encoder",
190+
"Lora not correctly set in unet",
191191
)
192192

193193
for name, param in pipe.unet.named_parameters():
@@ -208,6 +208,53 @@ def test_integration_move_lora_dora_cpu(self):
208208
if "lora_" in name:
209209
self.assertNotEqual(param.device, torch.device("cpu"))
210210

211+
@slow
212+
@require_torch_accelerator
213+
def test_integration_set_lora_device_different_target_layers(self):
214+
# fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
215+
# layers, see #11833
216+
from peft import LoraConfig
217+
218+
path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
219+
pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
220+
# configs partly target the same, partly different layers
221+
config0 = LoraConfig(target_modules=["to_k", "to_v"])
222+
config1 = LoraConfig(target_modules=["to_k", "to_q"])
223+
pipe.unet.add_adapter(config0, adapter_name="adapter-0")
224+
pipe.unet.add_adapter(config1, adapter_name="adapter-1")
225+
pipe = pipe.to(torch_device)
226+
227+
self.assertTrue(
228+
check_if_lora_correctly_set(pipe.unet),
229+
"Lora not correctly set in unet",
230+
)
231+
232+
# sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
233+
modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
234+
modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
235+
self.assertNotEqual(modules_adapter_0, modules_adapter_1)
236+
self.assertTrue(modules_adapter_0 - modules_adapter_1)
237+
self.assertTrue(modules_adapter_1 - modules_adapter_0)
238+
239+
# setting both separately works
240+
pipe.set_lora_device(["adapter-0"], "cpu")
241+
pipe.set_lora_device(["adapter-1"], "cpu")
242+
243+
for name, module in pipe.unet.named_modules():
244+
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
245+
self.assertTrue(module.weight.device == torch.device("cpu"))
246+
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
247+
self.assertTrue(module.weight.device == torch.device("cpu"))
248+
249+
# setting both at once also works
250+
pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
251+
252+
for name, module in pipe.unet.named_modules():
253+
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
254+
self.assertTrue(module.weight.device != torch.device("cpu"))
255+
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
256+
self.assertTrue(module.weight.device != torch.device("cpu"))
257+
211258

212259
@slow
213260
@nightly

0 commit comments

Comments
 (0)