@@ -1718,6 +1718,47 @@ def test_cfg(self):
17181718
17191719 assert out_cfg .shape == out_no_cfg .shape
17201720
1721+ def test_skip_guidance_layers (self ):
1722+ sig = inspect .signature (self .pipeline_class .__call__ )
1723+
1724+ if "skip_guidance_layers" not in sig .parameters :
1725+ return
1726+
1727+ components = self .get_dummy_components ()
1728+ pipe = self .pipeline_class (** components )
1729+ pipe = pipe .to (torch_device )
1730+ pipe .set_progress_bar_config (disable = None )
1731+
1732+ inputs = self .get_dummy_inputs (torch_device )
1733+
1734+ output_full = pipe (** inputs )[0 ]
1735+
1736+ inputs_with_skip = inputs .copy ()
1737+ inputs_with_skip ["skip_guidance_layers" ] = [0 ]
1738+ output_skip = pipe (** inputs_with_skip )[0 ]
1739+
1740+ self .assertFalse (
1741+ np .allclose (output_full , output_skip , atol = 1e-5 ), "Outputs should differ when layers are skipped"
1742+ )
1743+
1744+ self .assertEqual (output_full .shape , output_skip .shape , "Outputs should have the same shape" )
1745+
1746+ if "num_images_per_prompt" not in sig .parameters :
1747+ return
1748+
1749+ inputs ["num_images_per_prompt" ] = 2
1750+ output_full = pipe (** inputs )[0 ]
1751+
1752+ inputs_with_skip = inputs .copy ()
1753+ inputs_with_skip ["skip_guidance_layers" ] = [0 ]
1754+ output_skip = pipe (** inputs_with_skip )[0 ]
1755+
1756+ self .assertFalse (
1757+ np .allclose (output_full , output_skip , atol = 1e-5 ), "Outputs should differ when layers are skipped"
1758+ )
1759+
1760+ self .assertEqual (output_full .shape , output_skip .shape , "Outputs should have the same shape" )
1761+
17211762 def test_callback_inputs (self ):
17221763 sig = inspect .signature (self .pipeline_class .__call__ )
17231764 has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig .parameters
0 commit comments