@@ -1260,6 +1260,7 @@ def load_lora_weights(
                 adapter_name=adapter_name,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
             )

         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
@@ -1292,7 +1293,7 @@ def load_lora_weights(

     @classmethod
     def load_lora_into_transformer(
-        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1310,6 +1311,13 @@ def load_lora_into_transformer(
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
@@ -1324,6 +1332,7 @@ def load_lora_into_transformer(
             adapter_name=adapter_name,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod
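
The hunks above only thread the new flag through to the model-level loader; for orientation, the end-user call pattern this enables looks roughly like the sketch below. It is a minimal sketch, not part of the commit: the pipeline class and repository ids are placeholders, and the only behavior confirmed by this diff is that `hotswap=True` replaces the weights of an already loaded adapter referenced by `adapter_name`.

    import torch
    from diffusers import DiffusionPipeline  # placeholder; any pipeline whose loader mixin gained `hotswap`

    pipe = DiffusionPipeline.from_pretrained("org/base-model", torch_dtype=torch.float16).to("cuda")  # hypothetical repo id

    # First load: registers the adapter under an explicit name.
    pipe.load_lora_weights("org/lora-A", adapter_name="default")  # hypothetical repo id

    # ... run inference with adapter A ...

    # Hotswap: reuse the same adapter name so the existing adapter's weights
    # are replaced in-place instead of a second adapter being added.
    pipe.load_lora_weights("org/lora-B", adapter_name="default", hotswap=True)  # hypothetical repo id
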
@@ -1786,6 +1795,7 @@ def load_lora_weights(
                 adapter_name=adapter_name,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
             )

         if len(transformer_norm_state_dict) > 0:
@@ -1811,7 +1821,14 @@ def load_lora_weights(

     @classmethod
     def load_lora_into_transformer(
-        cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+        cls,
+        state_dict,
+        network_alphas,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1833,6 +1850,13 @@ def load_lora_into_transformer(
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
+            hotswap : (`bool`, *optional*)
+                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+                in-place. This means that, instead of loading an additional adapter, this will take the existing
+                adapter weights and replace them with the weights of the new adapter. This can be faster and more
+                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+                torch.compile, loading the new adapter does not require recompilation of the model. When using
+                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
         """
         if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
             raise ValueError(
@@ -1850,6 +1874,7 @@ def load_lora_into_transformer(
             adapter_name=adapter_name,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
         )

     @classmethod