@@ -68,7 +68,7 @@ def calculate_shift(
6868 return mu
6969
7070
71- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion .retrieve_timesteps
71+ # Copied from diffusers.pipelines.cogview4.pipeline_cogview4 .retrieve_timesteps
7272def retrieve_timesteps (
7373 scheduler ,
7474 num_inference_steps : Optional [int ] = None ,
@@ -100,10 +100,19 @@ def retrieve_timesteps(
100100 `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
101101 second element is the number of inference steps.
102102 """
103+ accepts_timesteps = "timesteps" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
104+ accepts_sigmas = "sigmas" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
105+
103106 if timesteps is not None and sigmas is not None :
104- raise ValueError ("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" )
105- if timesteps is not None :
106- accepts_timesteps = "timesteps" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
107+ if not accepts_timesteps and not accepts_sigmas :
108+ raise ValueError (
109+ f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
110+ f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
111+ )
112+ scheduler .set_timesteps (timesteps = timesteps , sigmas = sigmas , device = device , ** kwargs )
113+ timesteps = scheduler .timesteps
114+ num_inference_steps = len (timesteps )
115+ elif timesteps is not None and sigmas is None :
107116 if not accepts_timesteps :
108117 raise ValueError (
109118 f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
@@ -112,9 +121,8 @@ def retrieve_timesteps(
112121 scheduler .set_timesteps (timesteps = timesteps , device = device , ** kwargs )
113122 timesteps = scheduler .timesteps
114123 num_inference_steps = len (timesteps )
115- elif sigmas is not None :
116- accept_sigmas = "sigmas" in set (inspect .signature (scheduler .set_timesteps ).parameters .keys ())
117- if not accept_sigmas :
124+ elif timesteps is None and sigmas is not None :
125+ if not accepts_sigmas :
118126 raise ValueError (
119127 f"The current scheduler class { scheduler .__class__ } 's `set_timesteps` does not support custom"
120128 f" sigmas schedules. Please check whether you are using the correct scheduler."
@@ -515,8 +523,8 @@ def __call__(
515523 The output format of the generate image. Choose between
516524 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
517525 return_dict (`bool`, *optional*, defaults to `True`):
518- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput `] instead
519- of a plain tuple.
526+ Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput `] instead of a plain
527+ tuple.
520528 attention_kwargs (`dict`, *optional*):
521529 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
522530 `self.processor` in
@@ -532,7 +540,6 @@ def __call__(
532540 `._callback_tensor_inputs` attribute of your pipeline class.
533541 max_sequence_length (`int`, defaults to `224`):
534542 Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
535-
536543 Examples:
537544
538545 Returns:
0 commit comments