 import json
 import os
 import random
+import re
 import shutil
 import sys
 import tempfile
@@ -2239,12 +2240,23 @@ def get_dummy_input(self):
         return pipeline_inputs

     def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules):
-        # Similar to check_hotswap but more realistic: check a whole pipeline to be closer to how users would use it
-        from peft.utils.hotswap import prepare_model_for_compiled_hotswap
-
+        """
+        Check that hotswapping works on a pipeline.
+
+        Steps:
+        - create 2 LoRA adapters and save them
+        - load the first adapter
+        - hotswap the second adapter
+        - check that the outputs are correct
+        - optionally compile the model
+
+        Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
+        fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
+        fine.
+        """
+        # create 2 adapters with different ranks and alphas
         dummy_input = self.get_dummy_input()
         pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
-
         alpha0, alpha1 = rank0, rank1
         max_rank = max([rank0, rank1])
         lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules)
@@ -2266,6 +2278,7 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules):
         assert not (output1_before == 0).all()

         with tempfile.TemporaryDirectory() as tmp_dirname:
+            # save the adapter checkpoints
             lora0_state_dicts = self.get_lora_state_dicts({"unet": pipeline.unet}, adapter_name="adapter0")
             StableDiffusionPipeline.save_lora_weights(
                 save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts
@@ -2276,17 +2289,16 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules):
             )
             del pipeline

+            # load the first adapter
             pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+            if do_compile or (rank0 != rank1):
+                # no need to prepare if the model is not compiled and the ranks are identical
+                pipeline.enable_lora_hotswap(target_rank=max_rank)
+
             file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors")
             file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors")

             pipeline.load_lora_weights(file_name0)
-            if do_compile or (rank0 != rank1):
-                prepare_model_for_compiled_hotswap(
-                    pipeline.unet,
-                    config={"adapter0": lora_config0, "adapter1": lora_config1},
-                    target_rank=max_rank,
-                )
             if do_compile:
                 pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")

@@ -2295,6 +2307,7 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules):
             # sanity check: still same result
             assert np.allclose(output0_before, output0_after, atol=tol, rtol=tol)

+            # hotswap the 2nd adapter
             pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0")
             output1_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]

@@ -2327,3 +2340,12 @@ def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1)
         target_modules = ["to_q", "conv"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules)
+
+    def test_enable_lora_hotswap_called_too_late_raises(self):
+        # enable_lora_hotswap must be called before the first adapter is loaded; calling it afterwards should raise
+        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+        pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+        pipeline.unet.add_adapter(lora_config)
+        msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
+        with self.assertRaisesRegex(RuntimeError, msg):
+            pipeline.enable_lora_hotswap(target_rank=32)
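
For reference, a minimal sketch of the user-facing flow these tests exercise, assuming the enable_lora_hotswap / load_lora_weights(..., hotswap=True) API shown in the diff. The LoRA file paths, prompt, and device are placeholders, and "default_0" is assumed to be the name the first load_lora_weights call assigns, as in the test above.

import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to("cuda")

# Must be called before the first adapter is loaded; target_rank should be at least
# the largest rank among all adapters that will be hotswapped in.
pipeline.enable_lora_hotswap(target_rank=32)

# Load the first adapter, then optionally compile the UNet.
pipeline.load_lora_weights("lora0.safetensors")  # placeholder path
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")
image0 = pipeline("a prompt", generator=torch.manual_seed(0)).images[0]

# Hotswap the second adapter in place of the first; because enable_lora_hotswap was
# called up front, this should not trigger recompilation.
pipeline.load_lora_weights("lora1.safetensors", hotswap=True, adapter_name="default_0")
image1 = pipeline("a prompt", generator=torch.manual_seed(0)).images[0]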