File tree Expand file tree Collapse file tree 1 file changed +21
-0
lines changed Expand file tree Collapse file tree 1 file changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -1718,6 +1718,27 @@ def test_cfg(self):
17181718
17191719 assert out_cfg .shape == out_no_cfg .shape
17201720
1721+ def test_timesteps (self ):
1722+ sig = inspect .signature (self .pipeline_class .__call__ )
1723+
1724+ if "timesteps" not in sig .parameters :
1725+ return
1726+
1727+ components = self .get_dummy_components ()
1728+ pipe = self .pipeline_class (** components )
1729+ pipe = pipe .to (torch_device )
1730+ pipe .set_progress_bar_config (disable = None )
1731+
1732+ inputs = self .get_dummy_inputs (torch_device )
1733+
1734+ output_without_timesteps = pipe (** inputs )[0 ]
1735+
1736+ inputs = self .get_dummy_inputs (torch_device )
1737+ inputs ["timesteps" ] = [499 ]
1738+ output_with_timesteps = pipe (** inputs )[0 ]
1739+ max_diff = np .abs (output_without_timesteps - output_with_timesteps ).max ()
1740+ assert max_diff > 1e-4
1741+
17211742 def test_callback_inputs (self ):
17221743 sig = inspect .signature (self .pipeline_class .__call__ )
17231744 has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig .parameters
You can’t perform that action at this time.
0 commit comments