@@ -1718,6 +1718,47 @@ def test_cfg(self):
17181718
17191719        assert  out_cfg .shape  ==  out_no_cfg .shape 
17201720
1721+     def  test_skip_guidance_layers (self ):
1722+         sig  =  inspect .signature (self .pipeline_class .__call__ )
1723+ 
1724+         if  "skip_guidance_layers"  not  in sig .parameters :
1725+             return 
1726+ 
1727+         components  =  self .get_dummy_components ()
1728+         pipe  =  self .pipeline_class (** components )
1729+         pipe  =  pipe .to (torch_device )
1730+         pipe .set_progress_bar_config (disable = None )
1731+ 
1732+         inputs  =  self .get_dummy_inputs (torch_device )
1733+ 
1734+         output_full  =  pipe (** inputs )[0 ]
1735+ 
1736+         inputs_with_skip  =  inputs .copy ()
1737+         inputs_with_skip ["skip_guidance_layers" ] =  [0 ]
1738+         output_skip  =  pipe (** inputs_with_skip )[0 ]
1739+ 
1740+         self .assertFalse (
1741+             np .allclose (output_full , output_skip , atol = 1e-5 ), "Outputs should differ when layers are skipped" 
1742+         )
1743+ 
1744+         self .assertEqual (output_full .shape , output_skip .shape , "Outputs should have the same shape" )
1745+ 
1746+         if  "num_images_per_prompt"  not  in sig .parameters :
1747+             return 
1748+ 
1749+         inputs ["num_images_per_prompt" ] =  2 
1750+         output_full  =  pipe (** inputs )[0 ]
1751+ 
1752+         inputs_with_skip  =  inputs .copy ()
1753+         inputs_with_skip ["skip_guidance_layers" ] =  [0 ]
1754+         output_skip  =  pipe (** inputs_with_skip )[0 ]
1755+ 
1756+         self .assertFalse (
1757+             np .allclose (output_full , output_skip , atol = 1e-5 ), "Outputs should differ when layers are skipped" 
1758+         )
1759+ 
1760+         self .assertEqual (output_full .shape , output_skip .shape , "Outputs should have the same shape" )
1761+ 
17211762    def  test_callback_inputs (self ):
17221763        sig  =  inspect .signature (self .pipeline_class .__call__ )
17231764        has_callback_tensor_inputs  =  "callback_on_step_end_tensor_inputs"  in  sig .parameters 
0 commit comments