@@ -141,33 +141,30 @@ def test_simple_inference_with_text_lora_save_load(self):
141141        strict = False , 
142142    ) 
143143    def  test_lora_fuse_nan (self ):
144-         for  scheduler_cls  in  self .scheduler_classes :
145-             components , text_lora_config , denoiser_lora_config  =  self .get_dummy_components (scheduler_cls )
146-             pipe  =  self .pipeline_class (** components )
147-             pipe  =  pipe .to (torch_device )
148-             pipe .set_progress_bar_config (disable = None )
149-             _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
150- 
151-             if  "text_encoder"  in  self .pipeline_class ._lora_loadable_modules :
152-                 pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
153-                 self .assertTrue (
154-                     check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder" 
155-                 )
156- 
157-             denoiser  =  pipe .transformer  if  self .unet_kwargs  is  None  else  pipe .unet 
158-             denoiser .add_adapter (denoiser_lora_config , "adapter-1" )
159-             self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
160- 
161-             # corrupt one LoRA weight with `inf` values 
162-             with  torch .no_grad ():
163-                 pipe .transformer .layers [0 ].attn .to_q .lora_A ["adapter-1" ].weight  +=  float ("inf" )
164- 
165-             # with `safe_fusing=True` we should see an Error 
166-             with  self .assertRaises (ValueError ):
167-                 pipe .fuse_lora (components = self .pipeline_class ._lora_loadable_modules , safe_fusing = True )
168- 
169-             # without we should not see an error, but every image will be black 
170-             pipe .fuse_lora (components = self .pipeline_class ._lora_loadable_modules , safe_fusing = False )
171-             out  =  pipe (** inputs )[0 ]
172- 
173-             self .assertTrue (np .isnan (out ).all ())
144+         components , text_lora_config , denoiser_lora_config  =  self .get_dummy_components ()
145+         pipe  =  self .pipeline_class (** components )
146+         pipe  =  pipe .to (torch_device )
147+         pipe .set_progress_bar_config (disable = None )
148+         _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
149+ 
150+         if  "text_encoder"  in  self .pipeline_class ._lora_loadable_modules :
151+             pipe .text_encoder .add_adapter (text_lora_config , "adapter-1" )
152+             self .assertTrue (check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder" )
153+ 
154+         denoiser  =  pipe .transformer  if  self .unet_kwargs  is  None  else  pipe .unet 
155+         denoiser .add_adapter (denoiser_lora_config , "adapter-1" )
156+         self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
157+ 
158+         # corrupt one LoRA weight with `inf` values 
159+         with  torch .no_grad ():
160+             pipe .transformer .layers [0 ].attn .to_q .lora_A ["adapter-1" ].weight  +=  float ("inf" )
161+ 
162+         # with `safe_fusing=True` we should see an Error 
163+         with  self .assertRaises (ValueError ):
164+             pipe .fuse_lora (components = self .pipeline_class ._lora_loadable_modules , safe_fusing = True )
165+ 
166+         # without we should not see an error, but every image will be black 
167+         pipe .fuse_lora (components = self .pipeline_class ._lora_loadable_modules , safe_fusing = False )
168+         out  =  pipe (** inputs )[0 ]
169+ 
170+         self .assertTrue (np .isnan (out ).all ())
0 commit comments