
Commit 387ddf6

Update more methods with hotswap argument
- SDXL
- SD3
- Flux

No changes were made to load_lora_into_transformer.
1 parent c3c1bdf commit 387ddf6
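
For reference, below is roughly the workflow that the docstrings added in this commit describe — a minimal, hypothetical sketch only: the checkpoint ID, LoRA file names, adapter name, and rank value are placeholders and are not part of this commit.

```py
# Hypothetical usage sketch of the hotswap flow documented in the added docstrings.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Call *before* compiling and before loading the first LoRA adapter;
# target_rank should be the highest rank among all LoRAs you plan to swap in.
pipeline.enable_lora_hotswap(target_rank=64)

# Load the first adapter under an explicit name.
pipeline.load_lora_weights("lora_one.safetensors", adapter_name="style")

# Optionally compile now; hotswapping later does not force a recompile.
pipeline.unet = torch.compile(pipeline.unet)

# Replace the adapter's weights in place instead of loading a second adapter.
# Note: hotswapping LoRA layers of the text encoder is not yet supported.
pipeline.load_lora_weights("lora_two.safetensors", adapter_name="style", hotswap=True)
```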

File tree

1 file changed: +153 -2 lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 153 additions & 2 deletions
@@ -880,6 +880,7 @@ def load_lora_into_text_encoder(
         adapter_name=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
+        hotswap: bool = False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -905,6 +906,29 @@ def load_lora_into_text_encoder(
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                pipeline = ... # load diffusers pipeline
+                max_rank = ... # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
         """
         _load_lora_into_text_encoder(
             state_dict=state_dict,
@@ -916,6 +940,7 @@ def load_lora_into_text_encoder(
             adapter_name=adapter_name,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
@@ -1155,7 +1180,11 @@ def lora_state_dict(
         return state_dict

     def load_lora_weights(
-        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name=None,
+        hotswap: bool = False,
+        **kwargs,
     ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
@@ -1178,6 +1207,26 @@ def load_lora_weights(
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new
+                adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an
+                additional method before loading the adapter:
+                ```py
+                pipeline = ... # load diffusers pipeline
+                max_rank = ... # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
@@ -1224,6 +1273,7 @@ def load_lora_weights(
                 adapter_name=adapter_name,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
             )

         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -1237,6 +1287,7 @@ def load_lora_weights(
                 adapter_name=adapter_name,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
             )

     @classmethod
@@ -1287,6 +1338,7 @@ def load_lora_into_text_encoder(
         adapter_name=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
+        hotswap: bool = False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1312,6 +1364,29 @@ def load_lora_into_text_encoder(
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                pipeline = ... # load diffusers pipeline
+                max_rank = ... # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
         """
         _load_lora_into_text_encoder(
             state_dict=state_dict,
@@ -1323,6 +1398,7 @@ def load_lora_into_text_encoder(
             adapter_name=adapter_name,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
@@ -1600,7 +1676,11 @@ def lora_state_dict(
         return state_dict

     def load_lora_weights(
-        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name=None,
+        hotswap: bool = False,
+        **kwargs,
     ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -1625,6 +1705,26 @@ def load_lora_weights(
             low_cpu_mem_usage (`bool`, *optional*):
                 `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new
+                adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an
+                additional method before loading the adapter:
+                ```py
+                pipeline = ... # load diffusers pipeline
+                max_rank = ... # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -1706,6 +1806,7 @@ def load_lora_weights(
                 adapter_name=adapter_name,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
             )

     @classmethod
@@ -1817,6 +1918,7 @@ def load_lora_into_text_encoder(
         adapter_name=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
+        hotswap: bool = False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1842,6 +1944,29 @@ def load_lora_into_text_encoder(
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                pipeline = ... # load diffusers pipeline
+                max_rank = ... # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
         """
         _load_lora_into_text_encoder(
             state_dict=state_dict,
@@ -1853,6 +1978,7 @@ def load_lora_into_text_encoder(
             adapter_name=adapter_name,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
@@ -2312,6 +2438,7 @@ def load_lora_into_text_encoder(
         adapter_name=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
+        hotswap: bool = False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2337,6 +2464,29 @@ def load_lora_into_text_encoder(
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                pipeline = ... # load diffusers pipeline
+                max_rank = ... # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+                pipeline.load_lora_weights(file_name)
+                # optionally compile the model now
+                ```
+
+                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+                limitations to this technique, which are documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
         """
         _load_lora_into_text_encoder(
             state_dict=state_dict,
@@ -2348,6 +2498,7 @@ def load_lora_into_text_encoder(
             adapter_name=adapter_name,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
