Skip to content

Commit f55e3cc

Browse files
use cogview4 pipeline with timestep
1 parent b4e11e7 commit f55e3cc

File tree

2 files changed: 47 additions (+47) and 20 deletions (-20)

examples/cogview4-control/train_control_cogview4.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@
3737
from tqdm.auto import tqdm
3838

3939
import diffusers
40-
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, CogView4ControlPipeline,CogView4Transformer2DModel
40+
from diffusers import (
41+
AutoencoderKL,
42+
FlowMatchEulerDiscreteScheduler,
43+
CogView4ControlPipeline,
44+
CogView4Transformer2DModel,
45+
)
4146
from diffusers.optimization import get_scheduler
4247
from diffusers.training_utils import (
4348
compute_density_for_timestep_sampling,
@@ -787,7 +792,7 @@ def main(args):
787792

788793
# enable image inputs
789794
with torch.no_grad():
790-
patch_size = cogview4_transformer.config.patch_size
795+
patch_size = cogview4_transformer.config.patch_size
791796
initial_input_channels = cogview4_transformer.config.in_channels * patch_size**2
792797
new_linear = torch.nn.Linear(
793798
cogview4_transformer.patch_embed.proj.in_features * 2,
@@ -803,7 +808,9 @@ def main(args):
803808
cogview4_transformer.patch_embed.proj = new_linear
804809

805810
assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0)
806-
cogview4_transformer.register_to_config(in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels)
811+
cogview4_transformer.register_to_config(
812+
in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels
813+
)
807814

808815
if args.only_target_transformer_blocks:
809816
cogview4_transformer.patch_embed.proj.requires_grad_(True)
@@ -1050,34 +1057,41 @@ def load_model_hook(models, input_dir):
10501057
)
10511058

10521059
# Add noise according to CogView4
1053-
#FIXME: The issue of variable-length training has not been resolved, here it is still extended to the longest one.
1060+
# FIXME: The issue of variable-length training has not been resolved, here it is still extended to the longest one.
10541061
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
10551062
timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
10561063
sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device)
10571064
captions = batch["captions"]
1058-
image_seq_lens = torch.tensor(pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size ** 2, dtype=pixel_latents.dtype, device=pixel_latents.device) # H * W / VAE patch_size
1065+
image_seq_lens = torch.tensor(
1066+
pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size**2,
1067+
dtype=pixel_latents.dtype,
1068+
device=pixel_latents.device,
1069+
) # latent H * W // patch_size**2 (patch_size comes from the transformer config)
10591070
mu = torch.sqrt(image_seq_lens / 256)
10601071
mu = mu * 0.75 + 0.25
1061-
scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(dtype=pixel_latents.dtype, device=pixel_latents.device)
1072+
scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(
1073+
dtype=pixel_latents.dtype, device=pixel_latents.device
1074+
)
10621075
scale_factors = scale_factors.view(len(batch["captions"]), 1, 1, 1)
10631076
noisy_model_input = (1.0 - scale_factors) * pixel_latents + scale_factors * noise
10641077
concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)
10651078
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
10661079

10671080
with torch.no_grad():
1068-
prompt_embeds, pooled_prompt_embeds, = text_encoding_pipeline.encode_prompt(
1069-
captions, ""
1070-
)
1081+
(
1082+
prompt_embeds,
1083+
pooled_prompt_embeds,
1084+
) = text_encoding_pipeline.encode_prompt(captions, "")
10711085
original_size = (args.resolution, args.resolution)
10721086
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
10731087

1074-
target_size = (args.resolution,args.resolution)
1088+
target_size = (args.resolution, args.resolution)
10751089
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
10761090

10771091
target_size = target_size.repeat(len(batch["captions"]), 1)
10781092
original_size = original_size.repeat(len(batch["captions"]), 1)
10791093

1080-
#TODO: Should a parameter be set here for passing? This is not present in Flux.
1094+
# TODO: Should a parameter be set here for passing? This is not present in Flux.
10811095
crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
10821096
crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1)
10831097
# Predict.
@@ -1108,7 +1122,9 @@ def load_model_hook(models, input_dir):
11081122
target = noise - pixel_latents
11091123

11101124
weighting = weighting.view(len(batch["captions"]), 1, 1, 1)
1111-
loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),1)
1125+
loss = torch.mean(
1126+
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
1127+
)
11121128
loss = loss.mean()
11131129
accelerator.backward(loss)
11141130

src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def calculate_shift(
6565
mu = m * max_shift + base_shift
6666
return mu
6767

68+
6869
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
6970
def retrieve_timesteps(
7071
scheduler,
@@ -97,10 +98,19 @@ def retrieve_timesteps(
9798
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
9899
second element is the number of inference steps.
99100
"""
101+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
102+
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
103+
100104
if timesteps is not None and sigmas is not None:
101-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
102-
if timesteps is not None:
103-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
105+
if not accepts_timesteps and not accepts_sigmas:
106+
raise ValueError(
107+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
108+
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
109+
)
110+
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
111+
timesteps = scheduler.timesteps
112+
num_inference_steps = len(timesteps)
113+
elif timesteps is not None and sigmas is None:
104114
if not accepts_timesteps:
105115
raise ValueError(
106116
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -109,9 +119,8 @@ def retrieve_timesteps(
109119
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
110120
timesteps = scheduler.timesteps
111121
num_inference_steps = len(timesteps)
112-
elif sigmas is not None:
113-
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
114-
if not accept_sigmas:
122+
elif timesteps is None and sigmas is not None:
123+
if not accepts_sigmas:
115124
raise ValueError(
116125
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
117126
f" sigmas schedules. Please check whether you are using the correct scheduler."
@@ -630,8 +639,10 @@ def __call__(
630639
self.scheduler.config.get("base_shift", 0.25),
631640
self.scheduler.config.get("max_shift", 0.75),
632641
)
633-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu)
634-
642+
timesteps, num_inference_steps = retrieve_timesteps(
643+
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
644+
)
645+
self._num_timesteps = len(timesteps)
635646
# Denoising loop
636647
transformer_dtype = self.transformer.dtype
637648
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

0 commit comments

Comments
 (0)