Skip to content

Commit e0cd2df

Browse files
committed
Test timesteps in PipelineTesterMixin
1 parent 9ff7243 commit e0cd2df

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,6 +1718,27 @@ def test_cfg(self):
17181718

17191719
assert out_cfg.shape == out_no_cfg.shape
17201720

1721+
def test_timesteps(self):
1722+
sig = inspect.signature(self.pipeline_class.__call__)
1723+
1724+
if "timesteps" not in sig.parameters:
1725+
return
1726+
1727+
components = self.get_dummy_components()
1728+
pipe = self.pipeline_class(**components)
1729+
pipe = pipe.to(torch_device)
1730+
pipe.set_progress_bar_config(disable=None)
1731+
1732+
inputs = self.get_dummy_inputs(torch_device)
1733+
1734+
output_without_timesteps = pipe(**inputs)[0]
1735+
1736+
inputs = self.get_dummy_inputs(torch_device)
1737+
inputs["timesteps"] = [499]
1738+
output_with_timesteps = pipe(**inputs)[0]
1739+
max_diff = np.abs(output_without_timesteps - output_with_timesteps).max()
1740+
assert max_diff > 1e-4
1741+
17211742
def test_callback_inputs(self):
17221743
sig = inspect.signature(self.pipeline_class.__call__)
17231744
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters

0 commit comments

Comments
 (0)