@@ -111,30 +111,6 @@ def get_dummy_inputs(self, with_generator=True):
111111
112112        return  noise , input_ids , pipeline_inputs 
113113
114-     def  get_dummy_tensor_inputs (self , device = None ):
115-         batch_size  =  1 
116-         num_latent_channels  =  4 
117-         num_image_channels  =  3 
118-         height  =  width  =  4 
119-         sequence_length  =  48 
120-         embedding_dim  =  32 
121- 
122-         hidden_states  =  torch .randn ((batch_size , height  *  width , num_latent_channels )).to (torch_device )
123-         encoder_hidden_states  =  torch .randn ((batch_size , sequence_length , embedding_dim )).to (torch_device )
124-         pooled_prompt_embeds  =  torch .randn ((batch_size , embedding_dim )).to (torch_device )
125-         text_ids  =  torch .randn ((sequence_length , num_image_channels )).to (torch_device )
126-         image_ids  =  torch .randn ((height  *  width , num_image_channels )).to (torch_device )
127-         timestep  =  torch .tensor ([1.0 ]).to (torch_device ).expand (batch_size )
128- 
129-         return  {
130-             "hidden_states" : hidden_states ,
131-             "encoder_hidden_states" : encoder_hidden_states ,
132-             "pooled_projections" : pooled_prompt_embeds ,
133-             "txt_ids" : text_ids ,
134-             "img_ids" : image_ids ,
135-             "timestep" : timestep ,
136-         }
137- 
138114    def  test_with_alpha_in_state_dict (self ):
139115        components , _ , denoiser_lora_config  =  self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
140116        pipe  =  self .pipeline_class (** components )
@@ -189,13 +165,12 @@ def test_with_norm_in_state_dict(self):
189165        pipe  =  pipe .to (torch_device )
190166        pipe .set_progress_bar_config (disable = None )
191167
192-         inputs  =  self .get_dummy_tensor_inputs ( torch_device )
168+         _ ,  _ ,  inputs  =  self .get_dummy_inputs ( with_generator = False )
193169
194170        logger  =  logging .get_logger ("diffusers.loaders.lora_pipeline" )
195171        logger .setLevel (logging .INFO )
196172
197-         with  torch .no_grad ():
198-             original_output  =  pipe .transformer (** inputs )[0 ]
173+         original_output  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
199174
200175        for  norm_layer  in  ["norm_q" , "norm_k" , "norm_added_q" , "norm_added_k" ]:
201176            norm_state_dict  =  {}
@@ -206,18 +181,19 @@ def test_with_norm_in_state_dict(self):
206181                    module .weight .shape , device = module .weight .device , dtype = module .weight .dtype 
207182                )
208183
209-             with  torch .no_grad ():
210184                with  CaptureLogger (logger ) as  cap_logger :
211185                    pipe .load_lora_weights (norm_state_dict )
212-                     lora_load_output  =  pipe .transformer (** inputs )[0 ]
186+                     lora_load_output  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
187+ 
213188                self .assertTrue (
214189                    cap_logger .out .startswith (
215190                        "The provided state dict contains normalization layers in addition to LoRA layers" 
216191                    )
217192                )
193+                 self .assertTrue (len (pipe .transformer ._transformer_norm_layers ) >  0 )
218194
219195                pipe .unload_lora_weights ()
220-                 lora_unload_output  =  pipe . transformer (** inputs )[0 ]
196+                 lora_unload_output  =  pipe (** inputs ,  generator = torch . manual_seed ( 0 ) )[0 ]
221197
222198            self .assertTrue (pipe .transformer ._transformer_norm_layers  is  None )
223199            self .assertFalse (np .allclose (original_output , lora_load_output , atol = 1e-5 , rtol = 1e-5 ))
@@ -238,14 +214,11 @@ def test_lora_parameter_expanded_shapes(self):
238214        pipe  =  pipe .to (torch_device )
239215        pipe .set_progress_bar_config (disable = None )
240216
241-         inputs  =  self .get_dummy_tensor_inputs ( torch_device )
217+         _ ,  _ ,  inputs  =  self .get_dummy_inputs ( with_generator = False )
242218
243219        logger  =  logging .get_logger ("diffusers.loaders.lora_pipeline" )
244220        logger .setLevel (logging .DEBUG )
245221
246-         with  torch .no_grad ():
247-             original_output  =  pipe .transformer (** inputs )[0 ]
248- 
249222        out_features , in_features  =  pipe .transformer .x_embedder .weight .shape 
250223        rank  =  4 
251224
@@ -257,12 +230,12 @@ def test_lora_parameter_expanded_shapes(self):
257230        }
258231        with  CaptureLogger (logger ) as  cap_logger :
259232            pipe .load_lora_weights (lora_state_dict , "adapter-1" )
260-         inputs ["hidden_states" ] =  torch .cat ([inputs ["hidden_states" ]] *  2 , dim = 2 )
261-         with  torch .no_grad ():
262-             expanded_output  =  pipe .transformer (** inputs )[0 ]
233+ 
234+         self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] ==  2  *  in_features )
235+         self .assertTrue (pipe .transformer .config .in_channels  ==  2  *  in_features )
236+ 
263237        pipe .delete_adapters ("adapter-1" )
264238        self .assertTrue (cap_logger .out .startswith ("Expanding the nn.Linear input/output features for module" ))
265-         self .assertFalse (np .allclose (original_output , expanded_output , atol = 1e-3 , rtol = 1e-3 ))
266239
267240        components , _ , _  =  self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
268241        pipe  =  self .pipeline_class (** components )
@@ -286,24 +259,21 @@ def test_lora_B_bias(self):
286259        pipe  =  pipe .to (torch_device )
287260        pipe .set_progress_bar_config (disable = None )
288261
289-         inputs  =  self .get_dummy_tensor_inputs ( torch_device )
262+         _ ,  _ ,  inputs  =  self .get_dummy_inputs ( with_generator = False )
290263
291264        logger  =  logging .get_logger ("diffusers.loaders.lora_pipeline" )
292265        logger .setLevel (logging .INFO )
293266
294-         with  torch .no_grad ():
295-             original_output  =  pipe .transformer (** inputs )[0 ]
267+         original_output  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
296268
297269        denoiser_lora_config .lora_bias  =  False 
298270        pipe .transformer .add_adapter (denoiser_lora_config , "adapter-1" )
299-         with  torch .no_grad ():
300-             lora_bias_false_output  =  pipe .transformer (** inputs )[0 ]
271+         lora_bias_false_output  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
301272        pipe .delete_adapters ("adapter-1" )
302273
303274        denoiser_lora_config .lora_bias  =  True 
304275        pipe .transformer .add_adapter (denoiser_lora_config , "adapter-1" )
305-         with  torch .no_grad ():
306-             lora_bias_true_output  =  pipe .transformer (** inputs )[0 ]
276+         lora_bias_true_output  =  pipe (** inputs )[0 ]
307277
308278        self .assertFalse (np .allclose (original_output , lora_bias_false_output , atol = 1e-3 , rtol = 1e-3 ))
309279        self .assertFalse (np .allclose (original_output , lora_bias_true_output , atol = 1e-3 , rtol = 1e-3 ))
0 commit comments