
Commit 3a6677c

Add possibility to ignore `enable_lora_hotswap` being called at the wrong time
1 parent 69b637d commit 3a6677c

File tree

5 files changed: +124 -35 lines changed

- src/diffusers/loaders/lora_base.py
- src/diffusers/loaders/lora_pipeline.py
- src/diffusers/loaders/peft.py
- tests/models/test_modeling_common.py
- tests/pipelines/test_pipelines.py

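In short, the commit adds a `check_compiled` option to `enable_lora_hotswap` so that a call made at the wrong time (after the first adapter has been loaded, or on an already compiled model) can raise, warn, or be ignored instead of always raising. A minimal usage sketch against the pipeline-level API changed below; the tiny checkpoint comes from the test files and the LoRA file name is a placeholder:

```py
import torch
from diffusers import StableDiffusionPipeline

# Tiny checkpoint used by the test suite; "my_lora.safetensors" is a placeholder.
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe")

# Call before loading the first adapter. With check_compiled="warn", a call made
# too late only logs a warning instead of raising a RuntimeError.
pipeline.enable_lora_hotswap(target_rank=64, check_compiled="warn")

pipeline.load_lora_weights("my_lora.safetensors")
pipeline.unet = torch.compile(pipeline.unet)  # optionally compile afterwards
```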

src/diffusers/loaders/lora_base.py

Lines changed: 6 additions & 0 deletions
@@ -914,6 +914,12 @@ def enable_lora_hotswap(self, **kwargs) -> None:
         Args:
             target_rank (`int`):
                 The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle the case when the model is already compiled, which should generally be avoided. The
+                options are:
+                - "error" (default): raise an error
+                - "warn": issue a warning
+                - "ignore": do nothing
         """
         for key, component in self.components.items():
             if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):

src/diffusers/loaders/lora_pipeline.py

Lines changed: 15 additions & 18 deletions
@@ -118,13 +118,12 @@ def load_lora_weights(
             to call an additional method before loading the adapter:

             ```py
-            from peft.utils.hotswap import prepare_model_for_compiled_hotswap
-
-            model = ...  # load diffusers model with first LoRA adapter
+            pipeline = ...  # load diffusers pipeline
             max_rank = ...  # the highest rank among all LoRAs that you want to load
-            prepare_model_for_compiled_hotswap(model, target_rank=max_rank)  # call *before* compiling
-            model = torch.compile(model)
-            model.load_lora_adapter(..., hotswap=True)  # now hotswap the 2nd adapter
+            # call *before* compiling and loading the LoRA adapter
+            pipeline.enable_lora_hotswap(target_rank=max_rank)
+            pipeline.load_lora_weights(file_name)
+            # optionally compile the model now
             ```

             There are some limitations to this technique, which are documented here:

@@ -330,13 +329,12 @@ def load_lora_into_unet(
             to call an additional method before loading the adapter:

             ```py
-            from peft.utils.hotswap import prepare_model_for_compiled_hotswap
-
-            model = ...  # load diffusers model with first LoRA adapter
+            pipeline = ...  # load diffusers pipeline
             max_rank = ...  # the highest rank among all LoRAs that you want to load
-            prepare_model_for_compiled_hotswap(model, target_rank=max_rank)  # call *before* compiling
-            model = torch.compile(model)
-            model.load_lora_adapter(..., hotswap=True)  # now hotswap the 2nd adapter
+            # call *before* compiling and loading the LoRA adapter
+            pipeline.enable_lora_hotswap(target_rank=max_rank)
+            pipeline.load_lora_weights(file_name)
+            # optionally compile the model now
             ```

             There are some limitations to this technique, which are documented here:

@@ -800,13 +798,12 @@ def load_lora_into_unet(
             to call an additional method before loading the adapter:

             ```py
-            from peft.utils.hotswap import prepare_model_for_compiled_hotswap
-
-            model = ...  # load diffusers model with first LoRA adapter
+            pipeline = ...  # load diffusers pipeline
             max_rank = ...  # the highest rank among all LoRAs that you want to load
-            prepare_model_for_compiled_hotswap(model, target_rank=max_rank)  # call *before* compiling
-            model = torch.compile(model)
-            model.load_lora_adapter(..., hotswap=True)  # now hotswap the 2nd adapter
+            # call *before* compiling and loading the LoRA adapter
+            pipeline.enable_lora_hotswap(target_rank=max_rank)
+            pipeline.load_lora_weights(file_name)
+            # optionally compile the model now
             ```

             There are some limitations to this technique, which are documented here:
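The rewritten docstring example stops at "optionally compile the model now". As a hedged continuation, the sketch below hotswaps a second adapter after compiling; the LoRA file names are placeholders, and passing `hotswap=True` to `load_lora_weights` is an assumption here, mirroring the documented `load_lora_adapter(..., hotswap=True)` call at the model level:

```py
import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe")
max_rank = 64  # highest rank among all LoRAs that will be loaded

pipeline.enable_lora_hotswap(target_rank=max_rank)  # before the first adapter
pipeline.load_lora_weights("lora_one.safetensors")  # placeholder file name
pipeline.unet = torch.compile(pipeline.unet)        # compile once

# Swap in the second adapter without triggering recompilation; the hotswap
# keyword on load_lora_weights is assumed, by analogy with load_lora_adapter.
pipeline.load_lora_weights("lora_two.safetensors", hotswap=True)
```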

src/diffusers/loaders/peft.py

Lines changed: 33 additions & 15 deletions
@@ -16,7 +16,7 @@
 import os
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Literal, Optional, Union

 import safetensors
 import torch

@@ -144,8 +144,7 @@ def _optionally_disable_offloading(cls, _pipeline):
     def load_lora_adapter(
         self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
     ):
-        r"""
-        Loads a LoRA adapter into the underlying model.
+        r"""Loads a LoRA adapter into the underlying model.

         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):

@@ -194,21 +193,21 @@ def load_lora_adapter(
             However, the main advantage of hotswapping is that when the model is compiled with torch.compile,
             loading the new adapter does not require recompilation of the model.

-            If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
-            to call an additional method before loading the adapter:
+            If the model is compiled, or if the new adapter and the old adapter have different ranks and/or LoRA
+            alphas (i.e. scaling), you need to call an additional method before loading the adapter:

             ```py
-            from peft.utils.hotswap import prepare_model_for_compiled_hotswap
-
-            model = ...  # load diffusers model with first LoRA adapter
+            pipeline = ...  # load diffusers pipeline
             max_rank = ...  # the highest rank among all LoRAs that you want to load
-            prepare_model_for_compiled_hotswap(model, target_rank=max_rank)  # call *before* compiling
-            model = torch.compile(model)
-            model.load_lora_adapter(..., hotswap=True)  # now hotswap the 2nd adapter
+            # call *before* compiling and loading the LoRA adapter
+            pipeline.enable_lora_hotswap(target_rank=max_rank)
+            pipeline.load_lora_weights(file_name)
+            # optionally compile the model now
             ```

             There are some limitations to this technique, which are documented here:
             https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

@@ -837,16 +836,35 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             if hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)

-    def enable_lora_hotswap(self, target_rank: int) -> None:
+    def enable_lora_hotswap(
+        self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
+    ) -> None:
         """Enables the possibility to hotswap LoRA adapters.

         Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
         the loaded adapters differ.

         Args:
-            target_rank (`int`):
+            target_rank (`int`, *optional*, defaults to `128`):
                 The highest rank among all the adapters that will be loaded.
+
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle the case when the model is already compiled, which should generally be avoided. The
+                options are:
+                - "error" (default): raise an error
+                - "warn": issue a warning
+                - "ignore": do nothing
         """
         if getattr(self, "peft_config", {}):
-            raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
-        self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank}
+            if check_compiled == "error":
+                raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
+            elif check_compiled == "warn":
+                logger.warning(
+                    "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
+                )
+            elif check_compiled != "ignore":
+                raise ValueError(
+                    f"check_compiled should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
+                )
+
+        self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
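To make the new branching concrete, here is a small illustrative helper (not part of the diffusers API) showing how the three `check_compiled` modes behave when `enable_lora_hotswap` is called after an adapter has already been added, mirroring the tests further down:

```py
import logging

logger = logging.getLogger(__name__)


def enable_hotswap_leniently(model, target_rank: int = 32) -> None:
    """Illustrative helper: `model` is any diffusers model exposing
    `enable_lora_hotswap` (e.g. a UNet) that may already have a LoRA adapter."""
    try:
        # Default check_compiled="error": a late call raises RuntimeError.
        model.enable_lora_hotswap(target_rank=target_rank)
    except RuntimeError:
        # Downgrade the check to a log warning and continue ...
        logger.info("Late call rejected in strict mode, retrying leniently.")
        model.enable_lora_hotswap(target_rank=target_rank, check_compiled="warn")
        # ... or suppress the check entirely.
        model.enable_lora_hotswap(target_rank=target_rank, check_compiled="ignore")
```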

tests/models/test_modeling_common.py

Lines changed: 35 additions & 1 deletion
@@ -24,6 +24,7 @@
 import unittest
 import unittest.mock as mock
 import uuid
+import warnings
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Union

@@ -1827,11 +1828,44 @@ def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules)

-    def test_enable_lora_hotswap_called_too_late_raises(self):
+    def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
         lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
         unet = self.get_small_unet()
         unet.add_adapter(lora_config)
         msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
         with self.assertRaisesRegex(RuntimeError, msg):
             unet.enable_lora_hotswap(target_rank=32)
+
+    def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
+        # ensure that enable_lora_hotswap is called before loading the first adapter
+        from diffusers.loaders.peft import logger
+
+        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+        unet = self.get_small_unet()
+        unet.add_adapter(lora_config)
+        msg = (
+            "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
+        )
+        with self.assertLogs(logger=logger, level="WARNING") as cm:
+            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+        assert any(msg in log for log in cm.output)
+
+    def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
+        # check possibility to ignore the error/warning
+        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+        unet = self.get_small_unet()
+        unet.add_adapter(lora_config)
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")  # Capture all warnings
+            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+        self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
+
+    def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
+        # check that wrong argument value raises an error
+        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+        unet = self.get_small_unet()
+        unet.add_adapter(lora_config)
+        msg = re.escape("check_compiled should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
+        with self.assertRaisesRegex(ValueError, msg):
+            unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")

tests/pipelines/test_pipelines.py

Lines changed: 35 additions & 1 deletion
@@ -24,6 +24,7 @@
 import traceback
 import unittest
 import unittest.mock as mock
+import warnings

 import numpy as np
 import PIL.Image

@@ -2341,11 +2342,44 @@ def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1):
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules)

-    def test_enable_lora_hotswap_called_too_late_raises(self):
+    def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
         lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
         pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
         pipeline.unet.add_adapter(lora_config)
         msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
         with self.assertRaisesRegex(RuntimeError, msg):
             pipeline.enable_lora_hotswap(target_rank=32)
+
+    def test_enable_lora_hotswap_called_after_adapter_added_warns(self):
+        # ensure that enable_lora_hotswap is called before loading the first adapter
+        from diffusers.loaders.peft import logger
+
+        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+        pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+        pipeline.unet.add_adapter(lora_config)
+        msg = (
+            "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
+        )
+        with self.assertLogs(logger=logger, level="WARNING") as cm:
+            pipeline.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+        assert any(msg in log for log in cm.output)
+
+    def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
+        # check possibility to ignore the error/warning
+        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+        pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+        pipeline.unet.add_adapter(lora_config)
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")  # Capture all warnings
+            pipeline.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+        self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
+
+    def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
+        # check that wrong argument value raises an error
+        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+        pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+        pipeline.unet.add_adapter(lora_config)
+        msg = re.escape("check_compiled should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
+        with self.assertRaisesRegex(ValueError, msg):
+            pipeline.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
