Commit 1b834ec

Add enable_lora_hotswap method

1 parent e40390d · commit 1b834ec

5 files changed: +134 additions, -25 deletions

src/diffusers/loaders/lora_base.py

Lines changed: 14 additions & 0 deletions
@@ -898,3 +898,17 @@ def lora_scale(self) -> float:
         # property function that returns the lora scale which can be set at run time by the pipeline.
         # if _lora_scale has not been set, return 1
         return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+    def enable_lora_hotswap(self, **kwargs) -> None:
+        """Enables the possibility to hotswap LoRA adapters.
+
+        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
+        the loaded adapters differ.
+
+        Args:
+            target_rank (`int`):
+                The highest rank among all the adapters that will be loaded.
+        """
+        for component in self.components.values():
+            if hasattr(component, "enable_lora_hotswap"):
+                component.enable_lora_hotswap(**kwargs)
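For orientation, here is a minimal sketch of how this new pipeline-level entry point fits into the hotswap flow exercised by the pipeline test further down in this commit. The LoRA file paths, target rank, prompt, and device are placeholders; adjust them to your setup.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to("cuda")

# Must be called before the first adapter is loaded; the pipeline simply forwards
# the kwargs to every component that implements enable_lora_hotswap (e.g. the UNet).
pipe.enable_lora_hotswap(target_rank=64)  # highest rank among all adapters that will be loaded

pipe.load_lora_weights("path/to/adapter0.safetensors")  # placeholder path; the first adapter gets the name "default_0"
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
image0 = pipe("a prompt").images[0]

# Swap the second adapter into the already compiled UNet without triggering recompilation.
pipe.load_lora_weights("path/to/adapter1.safetensors", hotswap=True, adapter_name="default_0")
image1 = pipe("a prompt").images[0]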

src/diffusers/loaders/peft.py

Lines changed: 35 additions & 2 deletions
@@ -121,6 +121,8 @@ class PeftAdapterMixin:
     """
 
     _hf_peft_config_loaded = False
+    # kwargs for prepare_model_for_compiled_hotswap, if required
+    _prepare_lora_hotswap_kwargs: Optional[dict] = None
 
     @classmethod
     # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
@@ -325,9 +327,13 @@ def load_lora_adapter(
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
-            if hotswap:
+            if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
                 if is_peft_version(">", "0.14.0"):
-                    from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
+                    from peft.utils.hotswap import (
+                        check_hotswap_configs_compatible,
+                        hotswap_adapter_from_state_dict,
+                        prepare_model_for_compiled_hotswap,
+                    )
                 else:
                     msg = (
                         "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
@@ -366,6 +372,19 @@ def map_state_dict_for_hotswap(sd):
                 else:
                     inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
                     incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+
+                    if self._prepare_lora_hotswap_kwargs is not None:
+                        # For hotswapping of compiled models or adapters with different ranks.
+                        # If the user called enable_lora_hotswap, we need to ensure it is called:
+                        # - after the first adapter was loaded
+                        # - before the model is compiled and the 2nd adapter is being hotswapped in
+                        # Therefore, it needs to be called here
+                        prepare_model_for_compiled_hotswap(
+                            self, config=lora_config, **self._prepare_lora_hotswap_kwargs
+                        )
+                        # We only want to call prepare_model_for_compiled_hotswap once
+                        self._prepare_lora_hotswap_kwargs = None
+
             except Exception as e:
                 # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
                 if hasattr(self, "peft_config"):
@@ -816,3 +835,17 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             # Pop also the corresponding adapter from the config
             if hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)
+
+    def enable_lora_hotswap(self, target_rank: int) -> None:
+        """Enables the possibility to hotswap LoRA adapters.
+
+        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
+        the loaded adapters differ.
+
+        Args:
+            target_rank (`int`):
+                The highest rank among all the adapters that will be loaded.
+        """
+        if getattr(self, "peft_config", {}):
+            raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
+        self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank}
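The same flow at the model level, as exercised by the updated check_model_hotswap test further down: enable_lora_hotswap must be called between constructing the model and the first load_lora_adapter call, and compilation happens before the second adapter is hotswapped in. A minimal sketch; the model checkpoint, paths, and rank are illustrative, and the adapter files are assumed to be in the format written by save_lora_adapter.

import torch
from diffusers import UNet2DConditionModel

# Illustrative model; any diffusers model using PeftAdapterMixin behaves the same way.
unet = UNet2DConditionModel.from_pretrained("hf-internal-testing/tiny-sd-pipe", subfolder="unet")

# target_rank is the largest rank among all adapters that will be loaded; the first
# load_lora_adapter call will then run prepare_model_for_compiled_hotswap with it.
unet.enable_lora_hotswap(target_rank=16)

unet.load_lora_adapter("adapter0/pytorch_lora_weights.safetensors", adapter_name="adapter0")  # placeholder path
unet = torch.compile(unet, mode="reduce-overhead")

# Replace the weights of "adapter0" in place inside the compiled model.
unet.load_lora_adapter("adapter1/pytorch_lora_weights.safetensors", adapter_name="adapter0", hotswap=True)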

src/diffusers/loaders/unet.py

Lines changed: 36 additions & 3 deletions
@@ -15,7 +15,7 @@
 from collections import defaultdict
 from contextlib import nullcontext
 from pathlib import Path
-from typing import Callable, Dict, Union
+from typing import Callable, Dict, Optional, Union
 
 import safetensors
 import torch
@@ -62,6 +62,8 @@ class UNet2DConditionLoadersMixin:
 
     text_encoder_name = TEXT_ENCODER_NAME
     unet_name = UNET_NAME
+    # kwargs for prepare_model_for_compiled_hotswap, if required
+    _prepare_lora_hotswap_kwargs: Optional[dict] = None
 
     @validate_hf_hub_args
     def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
@@ -377,9 +379,13 @@ def _process_lora(
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
-            if hotswap:
+            if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
                 if is_peft_version(">", "0.14.0"):
-                    from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
+                    from peft.utils.hotswap import (
+                        check_hotswap_configs_compatible,
+                        hotswap_adapter_from_state_dict,
+                        prepare_model_for_compiled_hotswap,
+                    )
                 else:
                     msg = (
                         "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
@@ -417,6 +423,19 @@ def map_state_dict_for_hotswap(sd):
                 else:
                     inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
                     incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+
+                    if self._prepare_lora_hotswap_kwargs is not None:
+                        # For hotswapping of compiled models or adapters with different ranks.
+                        # If the user called enable_lora_hotswap, we need to ensure it is called:
+                        # - after the first adapter was loaded
+                        # - before the model is compiled and the 2nd adapter is being hotswapped in
+                        # Therefore, it needs to be called here
+                        prepare_model_for_compiled_hotswap(
+                            self, config=lora_config, **self._prepare_lora_hotswap_kwargs
+                        )
+                        # We only want to call prepare_model_for_compiled_hotswap once
+                        self._prepare_lora_hotswap_kwargs = None
+
             except Exception as e:
                 # TODO: add test in line with:
                 # https://github.com/huggingface/diffusers/pull/10188/files#diff-b544edcc938e163009735ef4fa963abd0a41615c175552160c9e0f94ceb7f552
@@ -1002,3 +1021,17 @@ def _load_ip_adapter_loras(self, state_dicts):
                 }
             )
         return lora_dicts
+
+    def enable_lora_hotswap(self, target_rank: int) -> None:
+        """Enables the possibility to hotswap LoRA adapters.
+
+        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
+        the loaded adapters differ.
+
+        Args:
+            target_rank (`int`):
+                The highest rank among all the adapters that will be loaded.
+        """
+        if getattr(self, "peft_config", {}):
+            raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
+        self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank}
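Both mixins add the same guard at the end: once a peft_config exists, enable_lora_hotswap refuses to run, because prepare_model_for_compiled_hotswap must be applied on the very first adapter load. A small sketch of that failure mode, mirroring the new tests below; the model checkpoint and LoRA config are illustrative.

from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained("hf-internal-testing/tiny-sd-pipe", subfolder="unet")
lora_config = LoraConfig(r=8, lora_alpha=8, target_modules=["to_q"])

unet.add_adapter(lora_config)  # an adapter is already present, so peft_config is set

try:
    unet.enable_lora_hotswap(target_rank=32)  # too late
except RuntimeError as err:
    print(err)  # Call `enable_lora_hotswap` before loading the first adapter.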

tests/models/test_modeling_common.py

Lines changed: 17 additions & 10 deletions
@@ -1638,10 +1638,8 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules):
         Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
         fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
         fine.
-
         """
-        from peft.utils.hotswap import prepare_model_for_compiled_hotswap
-
+        # create 2 adapters with different ranks and alphas
         dummy_input = self.get_dummy_input()
         alpha0, alpha1 = rank0, rank1
         max_rank = max([rank0, rank1])
@@ -1665,29 +1663,29 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules):
         assert not (output1_before == 0).all()
 
         with tempfile.TemporaryDirectory() as tmp_dirname:
+            # save the adapter checkpoints
             unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
             unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
             del unet
 
+            # load the first adapter
             unet = self.get_small_unet()
+            if do_compile or (rank0 != rank1):
+                # no need to prepare if the model is not compiled or if the ranks are identical
+                unet.enable_lora_hotswap(target_rank=max_rank)
+
             file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
             file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
             unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0")
 
-            if do_compile or (rank0 != rank1):
-                # no need to prepare if the model is not compiled or if the ranks are identical
-                prepare_model_for_compiled_hotswap(
-                    unet,
-                    config={"adapter0": lora_config0, "adapter1": lora_config1},
-                    target_rank=max_rank,
-                )
             if do_compile:
                 unet = torch.compile(unet, mode="reduce-overhead")
 
             with torch.inference_mode():
                 output0_after = unet(**dummy_input)["sample"]
             assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
+            # hotswap the 2nd adapter
            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True)
 
             # we need to call forward to potentially trigger recompilation
@@ -1727,3 +1725,12 @@ def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
         target_modules = ["to_q", "conv"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules)
+
+    def test_enable_lora_hotswap_called_too_late_raises(self):
+        # ensure that enable_lora_hotswap is called before loading the first adapter
+        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+        unet = self.get_small_unet()
+        unet.add_adapter(lora_config)
+        msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
+        with self.assertRaisesRegex(RuntimeError, msg):
+            unet.enable_lora_hotswap(target_rank=32)
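A side note on the error_on_recompile guard these tests rely on: inside that patch context, any torch.compile recompilation raises instead of happening silently, which is how the tests prove that hotswapping after enable_lora_hotswap leaves the compiled graph intact. A minimal, self-contained illustration; the function and tensor shapes are invented for this sketch.

import torch

@torch.compile
def matmul(x, w):
    return x @ w

x, w = torch.randn(4, 8), torch.randn(8, 8)
matmul(x, w)  # first call compiles the graph

with torch._dynamo.config.patch(error_on_recompile=True):
    matmul(torch.randn(4, 8), torch.randn(8, 8))  # same shapes/dtypes: cached graph is reused, no error
    # Calling with different shapes here would force a recompile and raise instead.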

tests/pipelines/test_pipelines.py

Lines changed: 32 additions & 10 deletions
@@ -17,6 +17,7 @@
 import json
 import os
 import random
+import re
 import shutil
 import sys
 import tempfile
@@ -2239,12 +2240,23 @@ def get_dummy_input(self):
         return pipeline_inputs
 
     def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules):
-        # Similar to check_hotswap but more realistic: check a whole pipeline to be closer to how users would use it
-        from peft.utils.hotswap import prepare_model_for_compiled_hotswap
-
+        """
+        Check that hotswapping works on a pipeline.
+
+        Steps:
+        - create 2 LoRA adapters and save them
+        - load the first adapter
+        - hotswap the second adapter
+        - check that the outputs are correct
+        - optionally compile the model
+
+        Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
+        fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
+        fine.
+        """
+        # create 2 adapters with different ranks and alphas
         dummy_input = self.get_dummy_input()
         pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
-
         alpha0, alpha1 = rank0, rank1
         max_rank = max([rank0, rank1])
         lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules)
@@ -2266,6 +2278,7 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules):
         assert not (output1_before == 0).all()
 
         with tempfile.TemporaryDirectory() as tmp_dirname:
+            # save the adapter checkpoints
             lora0_state_dicts = self.get_lora_state_dicts({"unet": pipeline.unet}, adapter_name="adapter0")
             StableDiffusionPipeline.save_lora_weights(
                 save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts
@@ -2276,17 +2289,16 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules):
             )
             del pipeline
 
+            # load the first adapter
             pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+            if do_compile or (rank0 != rank1):
+                # no need to prepare if the model is not compiled or if the ranks are identical
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+
             file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors")
             file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors")
 
             pipeline.load_lora_weights(file_name0)
-            if do_compile or (rank0 != rank1):
-                prepare_model_for_compiled_hotswap(
-                    pipeline.unet,
-                    config={"adapter0": lora_config0, "adapter1": lora_config1},
-                    target_rank=max_rank,
-                )
             if do_compile:
                 pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")
 
@@ -2295,6 +2307,7 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules):
             # sanity check: still same result
             assert np.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
+            # hotswap the 2nd adapter
             pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0")
             output1_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
 
@@ -2327,3 +2340,12 @@ def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1):
         target_modules = ["to_q", "conv"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules)
+
+    def test_enable_lora_hotswap_called_too_late_raises(self):
+        # ensure that enable_lora_hotswap is called before loading the first adapter
+        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+        pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+        pipeline.unet.add_adapter(lora_config)
+        msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
+        with self.assertRaisesRegex(RuntimeError, msg):
+            pipeline.enable_lora_hotswap(target_rank=32)
