Skip to content

Commit edbb562

Browse files
committed
Test skip_guidance_layers in pipelines
1 parent 63b631f commit edbb562

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,6 +1718,47 @@ def test_cfg(self):
17181718

17191719
assert out_cfg.shape == out_no_cfg.shape
17201720

1721+
def test_skip_guidance_layers(self):
1722+
sig = inspect.signature(self.pipeline_class.__call__)
1723+
1724+
if "skip_guidance_layers" not in sig.parameters:
1725+
return
1726+
1727+
components = self.get_dummy_components()
1728+
pipe = self.pipeline_class(**components)
1729+
pipe = pipe.to(torch_device)
1730+
pipe.set_progress_bar_config(disable=None)
1731+
1732+
inputs = self.get_dummy_inputs(torch_device)
1733+
1734+
output_full = pipe(**inputs)[0]
1735+
1736+
inputs_with_skip = inputs.copy()
1737+
inputs_with_skip["skip_guidance_layers"] = [0]
1738+
output_skip = pipe(**inputs_with_skip)[0]
1739+
1740+
self.assertFalse(
1741+
np.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
1742+
)
1743+
1744+
self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
1745+
1746+
if "num_images_per_prompt" not in sig.parameters:
1747+
return
1748+
1749+
inputs["num_images_per_prompt"] = 2
1750+
output_full = pipe(**inputs)[0]
1751+
1752+
inputs_with_skip = inputs.copy()
1753+
inputs_with_skip["skip_guidance_layers"] = [0]
1754+
output_skip = pipe(**inputs_with_skip)[0]
1755+
1756+
self.assertFalse(
1757+
np.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
1758+
)
1759+
1760+
self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
1761+
17211762
def test_callback_inputs(self):
17221763
sig = inspect.signature(self.pipeline_class.__call__)
17231764
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters

0 commit comments

Comments
 (0)