Skip to content

Commit 39e1198

Browse files
committed
use flow match scheduler instead of custom
1 parent 2046cf2 commit 39e1198

File tree

9 files changed

+136
-329
lines changed

9 files changed

+136
-329
lines changed

scripts/convert_cogview4_to_diffusers.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from accelerate import init_empty_weights
3333
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
3434

35-
from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView4Transformer2DModel
35+
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
3636
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
3737
from diffusers.utils.import_utils import is_accelerate_available
3838

@@ -222,19 +222,8 @@ def main(args):
222222
for param in text_encoder.parameters():
223223
param.data = param.data.contiguous()
224224

225-
scheduler = CogView4DDIMScheduler.from_config(
226-
{
227-
"shift_scale": 1.0,
228-
"beta_end": 0.012,
229-
"beta_schedule": "scaled_linear",
230-
"beta_start": 0.00085,
231-
"clip_sample": False,
232-
"num_train_timesteps": 1000,
233-
"prediction_type": "v_prediction",
234-
"rescale_betas_zero_snr": True,
235-
"set_alpha_to_one": True,
236-
"timestep_spacing": "linspace",
237-
}
225+
scheduler = FlowMatchEulerDiscreteScheduler(
226+
base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
238227
)
239228

240229
pipe = CogView4Pipeline(

scripts/convert_cogview4_to_diffusers_megatron.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,7 @@
2727
from tqdm import tqdm
2828
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
2929

30-
from diffusers import (
31-
AutoencoderKL,
32-
CogView4DDIMScheduler,
33-
CogView4Pipeline,
34-
CogView4Transformer2DModel,
35-
)
30+
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
3631
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
3732

3833

@@ -345,19 +340,8 @@ def main(args):
345340
param.data = param.data.contiguous()
346341

347342
# Initialize the scheduler
348-
scheduler = CogView4DDIMScheduler.from_config(
349-
{
350-
"shift_scale": 1.0,
351-
"beta_end": 0.012,
352-
"beta_schedule": "scaled_linear",
353-
"beta_start": 0.00085,
354-
"clip_sample": False,
355-
"num_train_timesteps": 1000,
356-
"prediction_type": "v_prediction",
357-
"rescale_betas_zero_snr": True,
358-
"set_alpha_to_one": True,
359-
"timestep_spacing": "linspace",
360-
}
343+
scheduler = FlowMatchEulerDiscreteScheduler(
344+
base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
361345
)
362346

363347
# Create the pipeline

src/diffusers/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@
188188
"CMStochasticIterativeScheduler",
189189
"CogVideoXDDIMScheduler",
190190
"CogVideoXDPMScheduler",
191-
"CogView4DDIMScheduler",
192191
"DDIMInverseScheduler",
193192
"DDIMParallelScheduler",
194193
"DDIMScheduler",
@@ -707,7 +706,6 @@
707706
CMStochasticIterativeScheduler,
708707
CogVideoXDDIMScheduler,
709708
CogVideoXDPMScheduler,
710-
CogView4DDIMScheduler,
711709
DDIMInverseScheduler,
712710
DDIMParallelScheduler,
713711
DDIMScheduler,

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 107 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,18 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import inspect
1617
from typing import Callable, Dict, List, Optional, Tuple, Union
1718

19+
import numpy as np
1820
import torch
1921
from transformers import AutoTokenizer, GlmModel
2022

2123
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2224
from ...image_processor import VaeImageProcessor
2325
from ...models import AutoencoderKL, CogView4Transformer2DModel
2426
from ...pipelines.pipeline_utils import DiffusionPipeline
25-
from ...schedulers import CogView4DDIMScheduler
27+
from ...schedulers import FlowMatchEulerDiscreteScheduler
2628
from ...utils import is_torch_xla_available, logging, replace_example_docstring
2729
from ...utils.torch_utils import randn_tensor
2830
from .pipeline_output import CogView4PipelineOutput
@@ -53,6 +55,82 @@
5355
"""
5456

5557

58+
def calculate_shift(
59+
image_seq_len,
60+
base_seq_len: int = 256,
61+
base_shift: float = 0.25,
62+
max_shift: float = 0.75,
63+
):
64+
# m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
65+
# b = base_shift - m * base_seq_len
66+
# mu = image_seq_len * m + b
67+
# return mu
68+
69+
m = (image_seq_len / base_seq_len) ** 0.5
70+
mu = m * max_shift + base_shift
71+
return mu
72+
73+
74+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
75+
def retrieve_timesteps(
76+
scheduler,
77+
num_inference_steps: Optional[int] = None,
78+
device: Optional[Union[str, torch.device]] = None,
79+
timesteps: Optional[List[int]] = None,
80+
sigmas: Optional[List[float]] = None,
81+
**kwargs,
82+
):
83+
r"""
84+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
85+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
86+
87+
Args:
88+
scheduler (`SchedulerMixin`):
89+
The scheduler to get timesteps from.
90+
num_inference_steps (`int`):
91+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
92+
must be `None`.
93+
device (`str` or `torch.device`, *optional*):
94+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
95+
timesteps (`List[int]`, *optional*):
96+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
97+
`num_inference_steps` and `sigmas` must be `None`.
98+
sigmas (`List[float]`, *optional*):
99+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
100+
`num_inference_steps` and `timesteps` must be `None`.
101+
102+
Returns:
103+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
104+
second element is the number of inference steps.
105+
"""
106+
if timesteps is not None and sigmas is not None:
107+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
108+
if timesteps is not None:
109+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
110+
if not accepts_timesteps:
111+
raise ValueError(
112+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
113+
f" timestep schedules. Please check whether you are using the correct scheduler."
114+
)
115+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
116+
timesteps = scheduler.timesteps
117+
num_inference_steps = len(timesteps)
118+
elif sigmas is not None:
119+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
120+
if not accept_sigmas:
121+
raise ValueError(
122+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
123+
f" sigmas schedules. Please check whether you are using the correct scheduler."
124+
)
125+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
126+
timesteps = scheduler.timesteps
127+
num_inference_steps = len(timesteps)
128+
else:
129+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
130+
timesteps = scheduler.timesteps
131+
return timesteps, num_inference_steps
132+
133+
56134
class CogView4Pipeline(DiffusionPipeline):
57135
r"""
58136
Pipeline for text-to-image generation using CogView4.
@@ -86,7 +164,7 @@ def __init__(
86164
text_encoder: GlmModel,
87165
vae: AutoencoderKL,
88166
transformer: CogView4Transformer2DModel,
89-
scheduler: CogView4DDIMScheduler,
167+
scheduler: FlowMatchEulerDiscreteScheduler,
90168
):
91169
super().__init__()
92170

@@ -219,8 +297,10 @@ def encode_prompt(
219297

220298
return prompt_embeds, negative_prompt_embeds
221299

222-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
223300
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
301+
if latents is not None:
302+
return latents.to(device)
303+
224304
shape = (
225305
batch_size,
226306
num_channels_latents,
@@ -232,14 +312,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
232312
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
233313
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
234314
)
235-
236-
if latents is None:
237-
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
238-
else:
239-
latents = latents.to(device)
240-
241-
# scale the initial noise by the standard deviation required by the scheduler
242-
latents = latents * self.scheduler.init_noise_sigma
315+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
243316
return latents
244317

245318
def check_inputs(
@@ -322,6 +395,7 @@ def __call__(
322395
width: Optional[int] = None,
323396
num_inference_steps: int = 50,
324397
timesteps: Optional[List[int]] = None,
398+
sigmas: Optional[List[float]] = None,
325399
guidance_scale: float = 5.0,
326400
num_images_per_prompt: int = 1,
327401
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -359,6 +433,10 @@ def __call__(
359433
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
360434
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
361435
passed will be used. Must be in descending order.
436+
sigmas (`List[float]`, *optional*):
437+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
438+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
439+
will be used.
362440
guidance_scale (`float`, *optional*, defaults to `5.0`):
363441
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
364442
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -491,9 +569,22 @@ def __call__(
491569
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
492570
self.transformer.config.patch_size**2
493571
)
494-
self.scheduler.set_timesteps(num_inference_steps, image_seq_len, device)
495-
timesteps = self.scheduler.timesteps
496-
self._num_timesteps = len(timesteps)
572+
573+
timesteps = (
574+
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
575+
if timesteps is None
576+
else np.array(timesteps)
577+
)
578+
timesteps = timesteps.astype(np.int64)
579+
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
580+
mu = calculate_shift(
581+
image_seq_len,
582+
self.scheduler.config.get("base_image_seq_len", 256),
583+
self.scheduler.config.get("base_shift", 0.25),
584+
self.scheduler.config.get("max_shift", 0.75),
585+
)
586+
_, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
587+
timesteps = torch.from_numpy(timesteps).to(device)
497588

498589
# Denoising loop
499590
transformer_dtype = self.transformer.dtype
@@ -504,8 +595,7 @@ def __call__(
504595
if self.interrupt:
505596
continue
506597

507-
latent_model_input = self.scheduler.scale_model_input(latents, t)
508-
latent_model_input = latent_model_input.to(transformer_dtype)
598+
latent_model_input = latents.to(transformer_dtype)
509599

510600
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
511601
timestep = t.expand(latents.shape[0])
@@ -536,7 +626,7 @@ def __call__(
536626
else:
537627
noise_pred = noise_pred_cond
538628

539-
latents = self.scheduler.step(noise_pred, latents, t).prev_sample
629+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
540630

541631
# call the callback, if provided
542632
if callback_on_step_end is not None:

src/diffusers/schedulers/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
4545
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
4646
_import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"]
47-
_import_structure["scheduling_ddim_cogview4"] = ["CogView4DDIMScheduler"]
4847
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
4948
_import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
5049
_import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
@@ -145,7 +144,6 @@
145144
from .scheduling_consistency_models import CMStochasticIterativeScheduler
146145
from .scheduling_ddim import DDIMScheduler
147146
from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
148-
from .scheduling_ddim_cogview4 import CogView4DDIMScheduler
149147
from .scheduling_ddim_inverse import DDIMInverseScheduler
150148
from .scheduling_ddim_parallel import DDIMParallelScheduler
151149
from .scheduling_ddpm import DDPMScheduler

0 commit comments

Comments
 (0)