@@ -121,7 +121,7 @@ def test_integration_move_lora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )
 
         # We will offload the first adapter in CPU and check if the offloading
@@ -188,7 +188,7 @@ def test_integration_move_lora_dora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )
 
         for name, param in pipe.unet.named_parameters():
@@ -222,6 +222,53 @@ def test_lora_set_adapters_scenarios(self, scenario):
             scenario=scenario, expected_atol=expected_atol, expected_rtol=expected_rtol
         )
 
+    @slow
+    @require_torch_accelerator
+    def test_integration_set_lora_device_different_target_layers(self):
+        # fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
+        # layers, see #11833
+        from peft import LoraConfig
+
+        path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+        # configs partly target the same, partly different layers
+        config0 = LoraConfig(target_modules=["to_k", "to_v"])
+        config1 = LoraConfig(target_modules=["to_k", "to_q"])
+        pipe.unet.add_adapter(config0, adapter_name="adapter-0")
+        pipe.unet.add_adapter(config1, adapter_name="adapter-1")
+        pipe = pipe.to(torch_device)
+
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in unet",
+        )
+
+        # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
+        modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
+        modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
+        self.assertNotEqual(modules_adapter_0, modules_adapter_1)
+        self.assertTrue(modules_adapter_0 - modules_adapter_1)
+        self.assertTrue(modules_adapter_1 - modules_adapter_0)
+
+        # setting both separately works
+        pipe.set_lora_device(["adapter-0"], "cpu")
+        pipe.set_lora_device(["adapter-1"], "cpu")
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+
+        # setting both at once also works
+        pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+
 
 @slow
 @nightly