@@ -120,7 +120,7 @@ def test_integration_move_lora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )
 
         # We will offload the first adapter in CPU and check if the offloading
@@ -187,7 +187,7 @@ def test_integration_move_lora_dora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )
 
         for name, param in pipe.unet.named_parameters():
@@ -208,6 +208,53 @@ def test_integration_move_lora_dora_cpu(self):
             if "lora_" in name:
                 self.assertNotEqual(param.device, torch.device("cpu"))
 
+    @slow
+    @require_torch_accelerator
+    def test_integration_set_lora_device_different_target_layers(self):
+        # fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
+        # layers, see #11833
+        from peft import LoraConfig
+
+        path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+        # configs partly target the same, partly different layers
+        config0 = LoraConfig(target_modules=["to_k", "to_v"])
+        config1 = LoraConfig(target_modules=["to_k", "to_q"])
+        pipe.unet.add_adapter(config0, adapter_name="adapter-0")
+        pipe.unet.add_adapter(config1, adapter_name="adapter-1")
+        pipe = pipe.to(torch_device)
+
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in unet",
+        )
+
+        # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
+        modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
+        modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
+        self.assertNotEqual(modules_adapter_0, modules_adapter_1)
+        self.assertTrue(modules_adapter_0 - modules_adapter_1)
+        self.assertTrue(modules_adapter_1 - modules_adapter_0)
+
+        # setting both separately works
+        pipe.set_lora_device(["adapter-0"], "cpu")
+        pipe.set_lora_device(["adapter-1"], "cpu")
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+
+        # setting both at once also works
+        pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+
 
 @slow
 @nightly