@@ -205,3 +205,85 @@ def test_simple_inference_with_transformer_lora_and_scale(self):
205
205
np .allclose (output_no_lora , output_lora_0_scale , atol = 1e-3 , rtol = 1e-3 ),
206
206
"Lora + 0 scale should lead to same result as no LoRA" ,
207
207
)
208
+
209
+ def test_simple_inference_with_transformer_fused (self ):
210
+ components = self .get_dummy_components ()
211
+ transformer_lora_config = self .get_lora_config_for_transformer ()
212
+ pipe = self .pipeline_class (** components )
213
+ pipe = pipe .to (torch_device )
214
+ pipe .set_progress_bar_config (disable = None )
215
+
216
+ inputs = self .get_dummy_inputs (torch_device )
217
+ output_no_lora = pipe (** inputs ).images
218
+
219
+ pipe .transformer .add_adapter (transformer_lora_config )
220
+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in transformer" )
221
+
222
+ pipe .fuse_lora ()
223
+ # Fusing should still keep the LoRA layers
224
+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in transformer" )
225
+
226
+ inputs = self .get_dummy_inputs (torch_device )
227
+ ouput_fused = pipe (** inputs ).images
228
+ self .assertFalse (
229
+ np .allclose (ouput_fused , output_no_lora , atol = 1e-3 , rtol = 1e-3 ), "Fused lora should change the output"
230
+ )
231
+
232
+ def test_simple_inference_with_transformer_fused_with_no_fusion (self ):
233
+ components = self .get_dummy_components ()
234
+ transformer_lora_config = self .get_lora_config_for_transformer ()
235
+ pipe = self .pipeline_class (** components )
236
+ pipe = pipe .to (torch_device )
237
+ pipe .set_progress_bar_config (disable = None )
238
+
239
+ inputs = self .get_dummy_inputs (torch_device )
240
+ output_no_lora = pipe (** inputs ).images
241
+
242
+ pipe .transformer .add_adapter (transformer_lora_config )
243
+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in transformer" )
244
+ inputs = self .get_dummy_inputs (torch_device )
245
+ ouput_lora = pipe (** inputs ).images
246
+
247
+ pipe .fuse_lora ()
248
+ # Fusing should still keep the LoRA layers
249
+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in transformer" )
250
+
251
+ inputs = self .get_dummy_inputs (torch_device )
252
+ ouput_fused = pipe (** inputs ).images
253
+ self .assertFalse (
254
+ np .allclose (ouput_fused , output_no_lora , atol = 1e-3 , rtol = 1e-3 ), "Fused lora should change the output"
255
+ )
256
+ self .assertTrue (
257
+ np .allclose (ouput_fused , ouput_lora , atol = 1e-3 , rtol = 1e-3 ),
258
+ "Fused lora output should be changed when LoRA isn't fused but still effective." ,
259
+ )
260
+
261
+ def test_simple_inference_with_transformer_fuse_unfuse (self ):
262
+ components = self .get_dummy_components ()
263
+ transformer_lora_config = self .get_lora_config_for_transformer ()
264
+ pipe = self .pipeline_class (** components )
265
+ pipe = pipe .to (torch_device )
266
+ pipe .set_progress_bar_config (disable = None )
267
+
268
+ inputs = self .get_dummy_inputs (torch_device )
269
+ output_no_lora = pipe (** inputs ).images
270
+
271
+ pipe .transformer .add_adapter (transformer_lora_config )
272
+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in transformer" )
273
+
274
+ pipe .fuse_lora ()
275
+ # Fusing should still keep the LoRA layers
276
+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in transformer" )
277
+ inputs = self .get_dummy_inputs (torch_device )
278
+ ouput_fused = pipe (** inputs ).images
279
+ self .assertFalse (
280
+ np .allclose (ouput_fused , output_no_lora , atol = 1e-3 , rtol = 1e-3 ), "Fused lora should change the output"
281
+ )
282
+
283
+ pipe .unfuse_lora ()
284
+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in transformer" )
285
+ inputs = self .get_dummy_inputs (torch_device )
286
+ output_unfused_lora = pipe (** inputs ).images
287
+ self .assertTrue (
288
+ np .allclose (ouput_fused , output_unfused_lora , atol = 1e-3 , rtol = 1e-3 ), "Fused lora should change the output"
289
+ )
0 commit comments