@@ -225,6 +225,39 @@ def test_fused_qkv_projections(self):
             original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
         ), "Original outputs should match when fused QKV projections are disabled."
 
+    def test_skip_guidance_layers(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        output_full = pipe(**inputs)[0]
+
+        inputs_with_skip = inputs.copy()
+        inputs_with_skip["skip_guidance_layers"] = [0]
+        output_skip = pipe(**inputs_with_skip)[0]
+
+        self.assertFalse(
+            np.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
+        )
+
+        self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
+
+        inputs["num_images_per_prompt"] = 2
+        output_full = pipe(**inputs)[0]
+
+        inputs_with_skip = inputs.copy()
+        inputs_with_skip["skip_guidance_layers"] = [0]
+        output_skip = pipe(**inputs_with_skip)[0]
+
+        self.assertFalse(
+            np.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
+        )
+
+        self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
+
 
 @slow
 @require_big_gpu_with_torch_cuda