@@ -225,6 +225,39 @@ def test_fused_qkv_projections(self):
225225 original_image_slice , image_slice_disabled , atol = 1e-2 , rtol = 1e-2
226226 ), "Original outputs should match when fused QKV projections are disabled."
227227
228+ def test_skip_guidance_layers (self ):
229+ components = self .get_dummy_components ()
230+ pipe = self .pipeline_class (** components )
231+ pipe = pipe .to (torch_device )
232+ pipe .set_progress_bar_config (disable = None )
233+
234+ inputs = self .get_dummy_inputs (torch_device )
235+
236+ output_full = pipe (** inputs )[0 ]
237+
238+ inputs_with_skip = inputs .copy ()
239+ inputs_with_skip ["skip_guidance_layers" ] = [0 ]
240+ output_skip = pipe (** inputs_with_skip )[0 ]
241+
242+ self .assertFalse (
243+ np .allclose (output_full , output_skip , atol = 1e-5 ), "Outputs should differ when layers are skipped"
244+ )
245+
246+ self .assertEqual (output_full .shape , output_skip .shape , "Outputs should have the same shape" )
247+
248+ inputs ["num_images_per_prompt" ] = 2
249+ output_full = pipe (** inputs )[0 ]
250+
251+ inputs_with_skip = inputs .copy ()
252+ inputs_with_skip ["skip_guidance_layers" ] = [0 ]
253+ output_skip = pipe (** inputs_with_skip )[0 ]
254+
255+ self .assertFalse (
256+ np .allclose (output_full , output_skip , atol = 1e-5 ), "Outputs should differ when layers are skipped"
257+ )
258+
259+ self .assertEqual (output_full .shape , output_skip .shape , "Outputs should have the same shape" )
260+
228261
229262@slow
230263@require_big_gpu_with_torch_cuda
0 commit comments