
Commit dec4d10

Add hotswap argument to load_lora_into_transformer
For SD3 and Flux. Use shorter docstring for brevity.
Parent: 7f72d0b

1 file changed: 27 additions, 2 deletions

src/diffusers/loaders/lora_pipeline.py

@@ -1260,6 +1260,7 @@ def load_lora_weights(
                 adapter_name=adapter_name,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
             )
 
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
@@ -1292,7 +1293,7 @@ def load_lora_weights(
 
     @classmethod
     def load_lora_into_transformer(
-        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1310,6 +1311,13 @@ def load_lora_into_transformer(
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
@@ -1324,6 +1332,7 @@ def load_lora_into_transformer(
             adapter_name=adapter_name,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )
 
     @classmethod
@@ -1786,6 +1795,7 @@ def load_lora_weights(
                 adapter_name=adapter_name,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
             )
 
         if len(transformer_norm_state_dict) > 0:
@@ -1811,7 +1821,14 @@ def load_lora_weights(
 
     @classmethod
     def load_lora_into_transformer(
-        cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+        cls,
+        state_dict,
+        network_alphas,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1833,6 +1850,13 @@ def load_lora_into_transformer(
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
         """
         if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
             raise ValueError(
@@ -1850,6 +1874,7 @@ def load_lora_into_transformer(
             adapter_name=adapter_name,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )
 
     @classmethod
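
As context for the SD3 docstring added above, here is a minimal sketch of the call pattern it describes, not taken from the commit itself: load an adapter under a name once, then reload with `hotswap=True` and the same `adapter_name` to replace that adapter's weights in place. The checkpoint paths and the "default" adapter name are placeholders, and the sketch assumes `load_lora_weights` on this branch already exposes the `hotswap` flag that this commit forwards into `load_lora_into_transformer`.

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# First load registers the adapter normally under a name.
pipe.load_lora_weights("path/to/lora_a.safetensors", adapter_name="default")
image_a = pipe("a photo of a corgi").images[0]

# Reloading with hotswap=True and the same adapter_name replaces the existing
# adapter's weights in place instead of registering a second adapter.
pipe.load_lora_weights("path/to/lora_b.safetensors", adapter_name="default", hotswap=True)
image_b = pipe("a photo of a corgi").images[0]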
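
And a sketch of the torch.compile motivation the Flux docstring names: because hotswapping replaces weights without changing the module structure, switching adapters should not force the compiled transformer to recompile. Again the paths are placeholders, the `hotswap` flag on `load_lora_weights` is assumed as above, and in practice the two adapters also need compatible ranks and scalings for the in-place swap to work.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

pipe.load_lora_weights("path/to/style_lora_1.safetensors", adapter_name="style")
pipe.transformer = torch.compile(pipe.transformer)

_ = pipe("a prompt").images[0]  # first call triggers compilation

# Hotswapping keeps the module structure unchanged, so the compiled graph can be
# reused and changing the adapter does not trigger recompilation.
pipe.load_lora_weights("path/to/style_lora_2.safetensors", adapter_name="style", hotswap=True)
_ = pipe("another prompt").images[0]  # no recompilation expected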
