@@ -1791,7 +1791,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_

        file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
        file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
-       unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0")
+       unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)

        if do_compile:
            unet = torch.compile(unet, mode="reduce-overhead")
@@ -1801,7 +1801,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
        assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)

        # hotswap the 2nd adapter
-       unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True)
+       unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)

        # we need to call forward to potentially trigger recompilation
        with torch.inference_mode():
@@ -1812,7 +1812,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
        name = "does-not-exist"
        msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
        with self.assertRaisesRegex(ValueError, msg):
-           unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True)
+           unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)

    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
    def test_hotswapping_model(self, rank0, rank1):
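
For context, a minimal sketch of the hot-swap flow this test exercises, assuming a UNet with PEFT/LoRA support and two previously saved LoRA checkpoints. The model id and the adapter file paths below are placeholders; the `load_lora_adapter` calls mirror the ones in the diff above, including the added `prefix=None`.

```python
import torch
from diffusers import UNet2DConditionModel

# Placeholder checkpoint id; any UNet with PEFT/LoRA support works the same way.
unet = UNet2DConditionModel.from_pretrained("some/tiny-sd-checkpoint", subfolder="unet")

# Load the first LoRA adapter. prefix=None matches the change in this diff:
# the saved state-dict keys are expected to carry no model prefix.
unet.load_lora_adapter("0/pytorch_lora_weights.safetensors", adapter_name="adapter0", prefix=None)

# Optionally compile once; hot-swapping aims to reuse the compiled graph
# instead of forcing a recompilation when the adapter weights change.
unet = torch.compile(unet, mode="reduce-overhead")

# Swap the weights of "adapter0" in place with a second checkpoint.
unet.load_lora_adapter(
    "1/pytorch_lora_weights.safetensors",
    adapter_name="adapter0",
    hotswap=True,
    prefix=None,
)

# Hot-swapping into an adapter name that was never loaded raises a ValueError:
# "Trying to hotswap LoRA adapter '<name>' but there is no existing adapter by that name"
```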