1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16+ import inspect
1617from typing import Callable , Dict , List , Optional , Tuple , Union
1718
19+ import numpy as np
1820import torch
1921from transformers import AutoTokenizer , GlmModel
2022
2123from ...callbacks import MultiPipelineCallbacks , PipelineCallback
2224from ...image_processor import VaeImageProcessor
2325from ...models import AutoencoderKL , CogView4Transformer2DModel
2426from ...pipelines .pipeline_utils import DiffusionPipeline
25- from ...schedulers import CogView4DDIMScheduler
27+ from ...schedulers import FlowMatchEulerDiscreteScheduler
2628from ...utils import is_torch_xla_available , logging , replace_example_docstring
2729from ...utils .torch_utils import randn_tensor
2830from .pipeline_output import CogView4PipelineOutput
5355"""
5456
5557
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    base_shift: float = 0.25,
    max_shift: float = 0.75,
) -> float:
    r"""
    Compute the shift value `mu` for a given image sequence length.

    The value grows with the square root of the ratio between the actual and the base sequence length:

        mu = sqrt(image_seq_len / base_seq_len) * max_shift + base_shift

    Note that, despite their names, `max_shift` is used here as the multiplicative factor on the square-root ratio
    and `base_shift` as the additive offset; `mu` is not clamped to `max_shift`. When
    `image_seq_len == base_seq_len` the result is `max_shift + base_shift`.

    Args:
        image_seq_len (`int`):
            Number of image tokens the shift is computed for.
        base_seq_len (`int`, defaults to `256`):
            Reference sequence length; must be non-zero.
        base_shift (`float`, defaults to `0.25`):
            Additive offset of the shift.
        max_shift (`float`, defaults to `0.75`):
            Factor applied to the square-root length ratio.

    Returns:
        `float`: The computed shift `mu`.
    """
    # Square-root scaling of the sequence-length ratio (not a linear interpolation between the two shifts).
    m = (image_seq_len / base_seq_len) ** 0.5
    mu = m * max_shift + base_shift
    return mu
72+
73+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
132+
133+
56134class CogView4Pipeline (DiffusionPipeline ):
57135 r"""
58136 Pipeline for text-to-image generation using CogView4.
@@ -86,7 +164,7 @@ def __init__(
86164 text_encoder : GlmModel ,
87165 vae : AutoencoderKL ,
88166 transformer : CogView4Transformer2DModel ,
89- scheduler : CogView4DDIMScheduler ,
167+ scheduler : FlowMatchEulerDiscreteScheduler ,
90168 ):
91169 super ().__init__ ()
92170
@@ -219,8 +297,10 @@ def encode_prompt(
219297
220298 return prompt_embeds , negative_prompt_embeds
221299
222- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
223300 def prepare_latents (self , batch_size , num_channels_latents , height , width , dtype , device , generator , latents = None ):
301+ if latents is not None :
302+ return latents .to (device )
303+
224304 shape = (
225305 batch_size ,
226306 num_channels_latents ,
@@ -232,14 +312,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
232312 f"You have passed a list of generators of length { len (generator )} , but requested an effective batch"
233313 f" size of { batch_size } . Make sure the batch size matches the length of the generators."
234314 )
235-
236- if latents is None :
237- latents = randn_tensor (shape , generator = generator , device = device , dtype = dtype )
238- else :
239- latents = latents .to (device )
240-
241- # scale the initial noise by the standard deviation required by the scheduler
242- latents = latents * self .scheduler .init_noise_sigma
315+ latents = randn_tensor (shape , generator = generator , device = device , dtype = dtype )
243316 return latents
244317
245318 def check_inputs (
@@ -322,6 +395,7 @@ def __call__(
322395 width : Optional [int ] = None ,
323396 num_inference_steps : int = 50 ,
324397 timesteps : Optional [List [int ]] = None ,
398+ sigmas : Optional [List [float ]] = None ,
325399 guidance_scale : float = 5.0 ,
326400 num_images_per_prompt : int = 1 ,
327401 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
@@ -359,6 +433,10 @@ def __call__(
359433 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
360434 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
361435 passed will be used. Must be in descending order.
436+ sigmas (`List[float]`, *optional*):
437+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
438+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
439+ will be used.
362440 guidance_scale (`float`, *optional*, defaults to `5.0`):
363441 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
364442 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -491,9 +569,22 @@ def __call__(
491569 image_seq_len = ((height // self .vae_scale_factor ) * (width // self .vae_scale_factor )) // (
492570 self .transformer .config .patch_size ** 2
493571 )
494- self .scheduler .set_timesteps (num_inference_steps , image_seq_len , device )
495- timesteps = self .scheduler .timesteps
496- self ._num_timesteps = len (timesteps )
572+
573+ timesteps = (
574+ np .linspace (self .scheduler .config .num_train_timesteps , 1.0 , num_inference_steps )
575+ if timesteps is None
576+ else np .array (timesteps )
577+ )
578+ timesteps = timesteps .astype (np .int64 )
579+ sigmas = timesteps / self .scheduler .config .num_train_timesteps if sigmas is None else sigmas
580+ mu = calculate_shift (
581+ image_seq_len ,
582+ self .scheduler .config .get ("base_image_seq_len" , 256 ),
583+ self .scheduler .config .get ("base_shift" , 0.25 ),
584+ self .scheduler .config .get ("max_shift" , 0.75 ),
585+ )
586+ _ , num_inference_steps = retrieve_timesteps (self .scheduler , num_inference_steps , device , sigmas = sigmas , mu = mu )
587+ timesteps = torch .from_numpy (timesteps ).to (device )
497588
498589 # Denoising loop
499590 transformer_dtype = self .transformer .dtype
@@ -504,8 +595,7 @@ def __call__(
504595 if self .interrupt :
505596 continue
506597
507- latent_model_input = self .scheduler .scale_model_input (latents , t )
508- latent_model_input = latent_model_input .to (transformer_dtype )
598+ latent_model_input = latents .to (transformer_dtype )
509599
510600 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
511601 timestep = t .expand (latents .shape [0 ])
@@ -536,7 +626,7 @@ def __call__(
536626 else :
537627 noise_pred = noise_pred_cond
538628
539- latents = self .scheduler .step (noise_pred , latents , t ). prev_sample
629+ latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[ 0 ]
540630
541631 # call the callback, if provided
542632 if callback_on_step_end is not None :
0 commit comments