From 45b6cb6f2c0284698b0c4f1904adab6d9062517f Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sat, 5 Oct 2024 13:05:32 +0800
Subject: [PATCH 01/30] merge 9588
---
scripts/convert_cogview3_to_diffusers.py | 183 +++++
show_model.py | 91 +++
show_model_cogview.py | 25 +
src/diffusers/__init__.py | 4 +
src/diffusers/models/__init__.py | 2 +
src/diffusers/models/attention_processor.py | 5 +-
src/diffusers/models/embeddings.py | 79 +-
src/diffusers/models/normalization.py | 45 ++
src/diffusers/models/transformers/__init__.py | 1 +
.../transformers/transformer_cogview3plus.py | 364 +++++++++
src/diffusers/pipelines/__init__.py | 4 +
src/diffusers/pipelines/auto_pipeline.py | 2 +
src/diffusers/pipelines/cogview3/__init__.py | 47 ++
.../cogview3/pipeline_cogview3plus.py | 707 ++++++++++++++++++
.../pipelines/cogview3/pipeline_output.py | 21 +
.../dummy_torch_and_transformers_objects.py | 15 +
16 files changed, 1590 insertions(+), 5 deletions(-)
create mode 100644 scripts/convert_cogview3_to_diffusers.py
create mode 100644 show_model.py
create mode 100644 show_model_cogview.py
create mode 100644 src/diffusers/models/transformers/transformer_cogview3plus.py
create mode 100644 src/diffusers/pipelines/cogview3/__init__.py
create mode 100644 src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
create mode 100644 src/diffusers/pipelines/cogview3/pipeline_output.py
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
new file mode 100644
index 000000000000..0af757969293
--- /dev/null
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -0,0 +1,183 @@
+"""
+Convert a CogView3 checkpoint to the Diffusers format.
+
+This script converts a CogView3 checkpoint to the Diffusers format, which can then be used
+with the Diffusers library.
+
+Example usage:
+ python scripts/convert_cogview3_to_diffusers.py \
+ --original_state_dict_repo_id "THUDM/cogview3" \
+ --filename "cogview3.pt" \
+ --transformer \
+ --output_path "./cogview3_diffusers" \
+ --dtype "bf16"
+
+Alternatively, if you have a local checkpoint:
+ python scripts/convert_cogview3_to_diffusers.py \
+ --checkpoint_path '/raid/.cache/huggingface/models--ZP2HF--CogView3-SAT/snapshots/ca86ce9ba94f9a7f2dd109e7a59e4c8ad04121be/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
+ --transformer \
+ --output_path "/raid/yiyi/cogview3_diffusers" \
+ --dtype "bf16"
+
+Arguments:
+ --original_state_dict_repo_id: The Hugging Face repo ID containing the original checkpoint.
+ --filename: The filename of the checkpoint in the repo (default: "flux.safetensors").
+ --checkpoint_path: Path to a local checkpoint file (alternative to repo_id and filename).
+ --transformer: Flag to convert the transformer model.
+ --output_path: The path to save the converted model.
+ --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32").
+
+Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
+"""
+
+import argparse
+from contextlib import nullcontext
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+
+from diffusers import CogView3PlusTransformer2DModel
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
+parser.add_argument("--filename", default="flux.safetensors", type=str)
+parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--transformer", action="store_true")
+parser.add_argument("--output_path", type=str)
+parser.add_argument("--dtype", type=str, default="bf16")
+
+args = parser.parse_args()
+
+
+def load_original_checkpoint(args):
+ if args.original_state_dict_repo_id is not None:
+ ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
+ elif args.checkpoint_path is not None:
+ ckpt_path = args.checkpoint_path
+ else:
+ raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")
+ return original_state_dict
+
+
+# this is specific to `AdaLayerNormContinuous`:
+# the diffusers implementation splits the linear projection into scale, shift while CogView3 splits it into shift, scale
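+# e.g. a fused projection whose rows are laid out as [shift_rows; scale_rows] is returned as [scale_rows; shift_rows]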
+def swap_scale_shift(weight, dim):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
+def convert_cogview3_transformer_checkpoint_to_diffusers(original_state_dict):
+ new_state_dict = {}
+
+ # Convert pos_embed
+ new_state_dict["pos_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
+ new_state_dict["pos_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
+ new_state_dict["pos_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
+ new_state_dict["pos_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
+
+ # Convert time_text_embed
+ new_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+ "time_embed.0.weight"
+ )
+ new_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_embed.0.bias")
+ new_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+ "time_embed.2.weight"
+ )
+ new_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_embed.2.bias")
+ new_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop("label_emb.0.0.weight")
+ new_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop("label_emb.0.0.bias")
+ new_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop("label_emb.0.2.weight")
+ new_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop("label_emb.0.2.bias")
+
+ # Convert transformer blocks
+ for i in range(30):
+ block_prefix = f"transformer_blocks.{i}."
+ old_prefix = f"transformer.layers.{i}."
+ adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
+
+ new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
+ new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
+
+ qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
+ qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
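+ # The SAT checkpoint stores the query/key/value projections fused along the output dimension,
+ # so chunking along dim 0 recovers the separate to_q/to_k/to_v weights and biases.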
+ q, k, v = qkv_weight.chunk(3, dim=0)
+ q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
+
+ new_state_dict[block_prefix + "attn.to_q.weight"] = q
+ new_state_dict[block_prefix + "attn.to_q.bias"] = q_bias
+ new_state_dict[block_prefix + "attn.to_k.weight"] = k
+ new_state_dict[block_prefix + "attn.to_k.bias"] = k_bias
+ new_state_dict[block_prefix + "attn.to_v.weight"] = v
+ new_state_dict[block_prefix + "attn.to_v.bias"] = v_bias
+
+ new_state_dict[block_prefix + "attn.to_out.0.weight"] = original_state_dict.pop(
+ old_prefix + "attention.dense.weight"
+ )
+ new_state_dict[block_prefix + "attn.to_out.0.bias"] = original_state_dict.pop(
+ old_prefix + "attention.dense.bias"
+ )
+
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
+ old_prefix + "mlp.dense_h_to_4h.weight"
+ )
+ new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
+ old_prefix + "mlp.dense_h_to_4h.bias"
+ )
+ new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
+ old_prefix + "mlp.dense_4h_to_h.weight"
+ )
+ new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
+
+ # Convert final norm and projection
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
+ )
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+ original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
+ )
+ new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
+ new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
+
+ return new_state_dict
+
+
+def main(args):
+ original_ckpt = load_original_checkpoint(args)
+ original_ckpt = original_ckpt["module"]
+ original_ckpt = {k.replace("model.diffusion_model.", ""): v for k, v in original_ckpt.items()}
+
+ original_dtype = next(iter(original_ckpt.values())).dtype
+ dtype = None
+ if args.dtype is None:
+ dtype = original_dtype
+ elif args.dtype == "fp16":
+ dtype = torch.float16
+ elif args.dtype == "bf16":
+ dtype = torch.bfloat16
+ elif args.dtype == "fp32":
+ dtype = torch.float32
+ else:
+ raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+ if args.transformer:
+ converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(original_ckpt)
+ transformer = CogView3PlusTransformer2DModel()
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+
+ print(f"Saving CogView3 Transformer in Diffusers format in {args.output_path}/transformer")
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ if len(original_ckpt) > 0:
+ print(f"Warning: {len(original_ckpt)} keys were not converted and will be saved as is: {original_ckpt.keys()}")
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/show_model.py b/show_model.py
new file mode 100644
index 000000000000..0127243117bd
--- /dev/null
+++ b/show_model.py
@@ -0,0 +1,91 @@
+import torch
+from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
+from diffusers import AutoencoderKL
+from huggingface_hub import hf_hub_download
+from sgm.models.autoencoder import AutoencodingEngine
+
+# (1) create vae_sat
+# AutoencodingEngine initialization arguments:
+encoder_config={'target': 'sgm.modules.diffusionmodules.model.Encoder', 'params': {'attn_type': 'vanilla', 'double_z': True, 'z_channels': 16, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 4, 8, 8], 'num_res_blocks': 3, 'attn_resolutions': [], 'mid_attn': False, 'dropout': 0.0}}
+decoder_config={'target': 'sgm.modules.diffusionmodules.model.Decoder', 'params': {'attn_type': 'vanilla', 'double_z': True, 'z_channels': 16, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 4, 8, 8], 'num_res_blocks': 3, 'attn_resolutions': [], 'mid_attn': False, 'dropout': 0.0}}
+loss_config={'target': 'torch.nn.Identity'}
+regularizer_config={'target': 'sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer'}
+optimizer_config=None
+lr_g_factor=1.0
+ckpt_path="/raid/.cache/huggingface/models--ZP2HF--CogView3-SAT/snapshots/ca86ce9ba94f9a7f2dd109e7a59e4c8ad04121be/3plus_ae/imagekl_ch16.pt"
+ignore_keys= []
+kwargs = {"monitor": "val/rec_loss"}
+vae_sat = AutoencodingEngine(
+ encoder_config=encoder_config,
+ decoder_config=decoder_config,
+ loss_config=loss_config,
+ regularizer_config=regularizer_config,
+ optimizer_config=optimizer_config,
+ lr_g_factor=lr_g_factor,
+ ckpt_path=ckpt_path,
+ ignore_keys=ignore_keys,
+ **kwargs)
+
+
+
+# (2) create vae (diffusers)
+ckpt_path_vae_cogview3 = hf_hub_download(repo_id="ZP2HF/CogView3-SAT", subfolder="3plus_ae", filename="imagekl_ch16.pt")
+cogview3_ckpt = torch.load(ckpt_path_vae_cogview3, map_location='cpu')["state_dict"]
+
+in_channels = 3 # Inferred from encoder.conv_in.weight shape
+out_channels = 3 # Inferred from decoder.conv_out.weight shape
+down_block_types = ("DownEncoderBlock2D",) * 4 # Inferred from the presence of 4 encoder.down blocks
+up_block_types = ("UpDecoderBlock2D",) * 4 # Inferred from the presence of 4 decoder.up blocks
+block_out_channels = (128, 512, 1024, 1024) # Inferred from the channel sizes in encoder.down blocks
+layers_per_block = 3 # Inferred from the number of blocks in each encoder.down and decoder.up
+act_fn = "silu" # This is the default, cannot be inferred from state_dict
+latent_channels = 16 # Inferred from decoder.conv_in.weight shape
+norm_num_groups = 32 # This is the default, cannot be inferred from state_dict
+sample_size = 1024 # This is the default, cannot be inferred from state_dict
+scaling_factor = 0.18215 # This is the default, cannot be inferred from state_dict
+force_upcast = True # This is the default, cannot be inferred from state_dict
+use_quant_conv = False # Inferred from the presence of encoder.conv_out
+use_post_quant_conv = False # Inferred from the presence of decoder.conv_in
+mid_block_add_attention = False # Inferred from the absence of attention layers in mid blocks
+
+vae = AutoencoderKL(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ down_block_types=down_block_types,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ latent_channels=latent_channels,
+ norm_num_groups=norm_num_groups,
+ sample_size=sample_size,
+ scaling_factor=scaling_factor,
+ force_upcast=force_upcast,
+ use_quant_conv=use_quant_conv,
+ use_post_quant_conv=use_post_quant_conv,
+ mid_block_add_attention=mid_block_add_attention,
+)
+
+vae.eval()
+vae_sat.eval()
+
+converted_vae_state_dict = convert_ldm_vae_checkpoint(cogview3_ckpt, vae.config)
+vae.load_state_dict(converted_vae_state_dict, strict=False)
+
+# (3) run forward pass for both models
+
+# [2, 16, 128, 128] -> [2, 3, 1024, 1024]
+z = torch.load("z.pt").float().to("cpu")
+
+with torch.no_grad():
+ print(" ")
+ print(f" running forward pass for diffusers vae")
+ out = vae.decode(z).sample
+ print(f" ")
+ print(f" running forward pass for sgm vae")
+ out_sat = vae_sat.decode(z)
+
+print(f" output shape: {out.shape}")
+print(f" expected output shape: {out_sat.shape}")
+assert out.shape == out_sat.shape
+assert (out - out_sat).abs().max() < 1e-4, f"max diff: {(out - out_sat).abs().max()}"
\ No newline at end of file
diff --git a/show_model_cogview.py b/show_model_cogview.py
new file mode 100644
index 000000000000..5314930cb127
--- /dev/null
+++ b/show_model_cogview.py
@@ -0,0 +1,25 @@
+import torch
+from diffusers import CogView3PlusTransformer2DModel
+
+model = CogView3PlusTransformer2DModel.from_pretrained("/share/home/zyx/Models/CogView3Plus_hf/transformer", torch_dtype=torch.bfloat16)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+batch_size = 1
+hidden_states = torch.ones((batch_size, 16, 256, 256), device=device, dtype=torch.bfloat16)
+timestep = torch.full((batch_size,), 999.0, device=device, dtype=torch.bfloat16)
+encoder_hidden_states = torch.ones((batch_size, 256, 4096), device=device, dtype=torch.bfloat16)  # dummy text embeddings (dim matches encoder_hidden_states_dim)
+pooled_projections = torch.ones((batch_size, 1536), device=device, dtype=torch.bfloat16)  # matches pooled_projection_dim
+
+# Call the forward method with dummy inputs matching the transformer signature defined in this patch
+outputs = model(
+ hidden_states=hidden_states, # latent image input
+ encoder_hidden_states=encoder_hidden_states, # text encoder embeddings
+ timestep=timestep, # denoising timestep
+ pooled_projections=pooled_projections, # pooled text embeddings (label input)
+ return_dict=True, # keep the default
+)
+
+# Print the model output shape
+print("Output shape:", outputs.sample.shape)
\ No newline at end of file
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 4214a4699ec8..978e7047e666 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -84,6 +84,7 @@
"AutoencoderOobleck",
"AutoencoderTiny",
"CogVideoXTransformer3DModel",
+ "CogView3PlusTransformer2DModel",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
@@ -258,6 +259,7 @@
"CogVideoXImageToVideoPipeline",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
+ "CogView3PlusPipeline",
"CycleDiffusionPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
@@ -558,6 +560,7 @@
AutoencoderOobleck,
AutoencoderTiny,
CogVideoXTransformer3DModel,
+ CogView3PlusTransformer2DModel,
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
@@ -710,6 +713,7 @@
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
+ CogView3PlusPipeline,
CycleDiffusionPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index f0dd7248c117..4dda8c36ba1c 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -54,6 +54,7 @@
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+ _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -98,6 +99,7 @@
from .transformers import (
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
+ CogView3PlusTransformer2DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
FluxTransformer2DModel,
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 9f9bc5a46e10..2ef3c6a80830 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -122,6 +122,7 @@ def __init__(
out_dim: int = None,
context_pre_only=None,
pre_only=False,
+ layernorm_elementwise_affine: bool = True,
):
super().__init__()
@@ -179,8 +180,8 @@ def __init__(
self.norm_q = None
self.norm_k = None
elif qk_norm == "layer_norm":
- self.norm_q = nn.LayerNorm(dim_head, eps=eps)
- self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=layernorm_elementwise_affine)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=layernorm_elementwise_affine)
elif qk_norm == "fp32_layer_norm":
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index c250df29afbe..cddf46cfea77 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -714,6 +714,58 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
return freqs_cos, freqs_sin
+class CogView3PlusPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 16,
+ hidden_size: int = 2560,
+ patch_size: int = 2,
+ text_hidden_size: int = 4096,
+ pos_embed_max_size: int = 128,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_size = hidden_size
+ self.patch_size = patch_size
+ self.text_hidden_size = text_hidden_size
+ self.pos_embed_max_size = pos_embed_max_size
+ # Linear projection for image patches
+ self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+ # Linear projection for text embeddings
+ self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
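+ # Precompute a full (pos_embed_max_size x pos_embed_max_size) grid of 2D sin-cos positional
+ # embeddings; forward() crops the top-left (height x width) patch grid out of it.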
+ pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+ pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None) -> torch.Tensor:
+ batch_size, channel, height, width = hidden_states.shape
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
+ raise ValueError("Height and width must be divisible by patch size")
+ height = height // self.patch_size
+ width = width // self.patch_size
+ hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
+ hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
+ hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
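+ # Shape flow: (B, C, H, W) -> (B, H/p * W/p, C * p * p), i.e. non-overlapping p x p patches flattened into tokens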
+
+ # Project the patches
+ hidden_states = self.proj(hidden_states)
+ encoder_hidden_states = self.text_proj(encoder_hidden_states)
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ # Calculate text_length
+ text_length = encoder_hidden_states.shape[1]
+
+ image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
+ text_pos_embed = torch.zeros(
+ (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+ )
+ pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
+
+ return (hidden_states + pos_embed).to(hidden_states.dtype)
+
+
class TimestepEmbedding(nn.Module):
def __init__(
self,
@@ -1018,6 +1070,27 @@ def forward(self, image_embeds: torch.Tensor):
return self.norm(x)
+class CogView3CombineTimestepLabelEmbedding(nn.Module):
+ def __init__(self, time_embed_dim, label_embed_dim, in_channels=2560):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=1)
+ self.timestep_embedder = TimestepEmbedding(in_channels=in_channels, time_embed_dim=time_embed_dim)
+ self.label_embedder = nn.Sequential(
+ nn.Linear(label_embed_dim, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim),
+ )
+
+ def forward(self, timestep, class_labels, hidden_dtype=None):
+ t_proj = self.time_proj(timestep)
+ t_emb = self.timestep_embedder(t_proj.to(dtype=hidden_dtype))
+ label_emb = self.label_embedder(class_labels)
+ emb = t_emb + label_emb
+
+ return emb
+
+
class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
@@ -1038,11 +1111,11 @@ def forward(self, timestep, class_labels, hidden_dtype=None):
class CombinedTimestepTextProjEmbeddings(nn.Module):
- def __init__(self, embedding_dim, pooled_projection_dim):
+ def __init__(self, embedding_dim, pooled_projection_dim, timesteps_dim=256):
super().__init__()
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(self, timestep, pooled_projection):
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 5740fed9f30c..21e9d3cd6fc5 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -355,6 +355,51 @@ def forward(
return x
+class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
+ r"""
+ Adaptive layer norm zero (adaLN-Zero) for the joint text and image streams used in CogView3.
+
+ Parameters:
+ embedding_dim (`int`): The size of the conditioning embedding vector.
+ dim (`int`): The hidden dimension of the image and text streams to be normalized.
+ """
+
+ def __init__(self, embedding_dim: int, dim: int):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+ self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ emb = self.linear(self.silu(emb))
+ (
+ shift_msa,
+ scale_msa,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ c_shift_msa,
+ c_scale_msa,
+ c_gate_msa,
+ c_shift_mlp,
+ c_scale_mlp,
+ c_gate_mlp,
+ ) = emb.chunk(12, dim=1)
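+ # A single conditioning embedding yields 12 modulation tensors: shift/scale/gate for the attention
+ # and MLP branches of the image stream, plus the corresponding c_* tensors for the text stream.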
+ normed_x = self.norm_x(x)
+ normed_context = self.norm_c(context)
+ x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
+
+
class CogVideoXLayerNormZero(nn.Module):
def __init__(
self,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index d55dfe57d6f3..58787c079ea8 100644
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -14,6 +14,7 @@
from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
+ from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
new file mode 100644
index 000000000000..8d7e37f0c925
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -0,0 +1,364 @@
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.attention import FeedForward
+from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ...models.modeling_utils import ModelMixin
+from ...models.normalization import AdaLayerNormContinuous
+from ...utils import is_torch_version, logging
+from ..embeddings import CogView3PlusPatchEmbed, CombinedTimestepTextProjEmbeddings
+from ..modeling_outputs import Transformer2DModelOutput
+from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class CogView3PlusTransformerBlock(nn.Module):
+ """
+ Transformer block used in CogView3Plus, following the SAT AdalnAttentionMixin design with query/key layer norm always enabled.
+ """
+
+ def __init__(
+ self,
+ dim: int = 2560,
+ num_attention_heads: int = 64,
+ attention_head_dim: int = 40,
+ time_embed_dim: int = 512,
+ ):
+ super().__init__()
+
+ self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
+
+ self.attn = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ out_dim=dim,
+ bias=True,
+ qk_norm="layer_norm",
+ layernorm_elementwise_affine=False,
+ eps=1e-6,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ emb: torch.Tensor,
+ text_length: int,
+ ) -> torch.Tensor:
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_length], hidden_states[:, text_length:]
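+ # The joint sequence is laid out as [text tokens | image tokens]; the two streams are modulated
+ # separately below but attend to each other in a single joint attention call.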
+
+ # norm1
+ (
+ norm_hidden_states,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ norm_encoder_hidden_states,
+ c_gate_msa,
+ c_shift_mlp,
+ c_scale_mlp,
+ c_gate_mlp,
+ ) = self.norm1(hidden_states, encoder_hidden_states, emb)
+
+ # Attention
+ attn_input = torch.cat((norm_encoder_hidden_states, norm_hidden_states), dim=1)
+ attn_output = self.attn(hidden_states=attn_input)
+ context_attn_output, attn_output = attn_output[:, :text_length], attn_output[:, text_length:]
+
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ norm_hidden_states = torch.cat((norm_encoder_hidden_states, norm_hidden_states), dim=1)
+
+ ff_output = self.ff(norm_hidden_states)
+
+ context_ff_output, ff_output = ff_output[:, :text_length], ff_output[:, text_length:]
+
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+ hidden_states = torch.cat((encoder_hidden_states, hidden_states), dim=1)
+
+ return hidden_states
+
+
+class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ """
+ The Transformer model introduced in CogView3.
+
+ Reference: https://arxiv.org/abs/2403.05121
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 128,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ num_layers: int = 30,
+ attention_head_dim: int = 40,
+ num_attention_heads: int = 64,
+ out_channels: int = 16,
+ encoder_hidden_states_dim: int = 4096,
+ pooled_projection_dim: int = 1536,
+ pos_embed_max_size: int = 128,
+ time_embed_dim: int = 512,
+ ):
+ super().__init__()
+ self.out_channels = out_channels
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+ self.pos_embed = CogView3PlusPatchEmbed(
+ in_channels=self.config.in_channels,
+ hidden_size=self.inner_dim,
+ patch_size=self.config.patch_size,
+ text_hidden_size=self.config.encoder_hidden_states_dim,
+ pos_embed_max_size=self.config.pos_embed_max_size,
+ )
+
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+ embedding_dim=self.config.time_embed_dim,
+ pooled_projection_dim=self.config.pooled_projection_dim,
+ timesteps_dim=self.inner_dim,
+ )
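+ # CogView3 uses sinusoidal timestep features of size inner_dim rather than the usual 256, which is
+ # why `timesteps_dim` is added to `CombinedTimestepTextProjEmbeddings` in embeddings.py above.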
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogView3PlusTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ time_embed_dim=self.config.time_embed_dim,
+ )
+ for _ in range(self.config.num_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(
+ embedding_dim=self.inner_dim,
+ conditioning_embedding_dim=self.config.time_embed_dim,
+ elementwise_affine=False,
+ eps=1e-6,
+ )
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedJointAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ pooled_projections: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`CogView3PlusTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor`): Input latents of shape `(batch_size, in_channels, height, width)`.
+ encoder_hidden_states (`torch.FloatTensor`): Text embeddings of shape `(batch_size, text_seq_len, encoder_hidden_states_dim)`.
+ pooled_projections (`torch.FloatTensor`): Pooled text embeddings combined with the timestep embedding for conditioning.
+ timestep (`torch.LongTensor`): Indicates the current denoising step.
+ return_dict (`bool`, *optional*, defaults to `True`): Whether to return a `Transformer2DModelOutput`.
+
+ Returns:
+ Output tensor or `Transformer2DModelOutput`.
+ """
+
+ height, width = hidden_states.shape[-2:]
+ text_length = encoder_hidden_states.shape[1]
+
+ hidden_states = self.pos_embed(
+ hidden_states, encoder_hidden_states
+ ) # takes care of adding positional embeddings too.
+ emb = self.time_text_embed(timestep, pooled_projections)
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ emb,
+ text_length,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ emb=emb,
+ text_length=text_length,
+ )
+
+ hidden_states = hidden_states[:, text_length:]
+ hidden_states = self.norm_out(hidden_states, emb)
+ hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels)
+ # unpatchify
+ patch_size = self.config.patch_size
+ height = height // patch_size
+ width = width // patch_size
+
+ hidden_states = hidden_states.reshape(
+ shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
+ )
+ hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ )
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 3b6cde17c8a3..7ff219efdf5f 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -145,6 +145,9 @@
"CogVideoXImageToVideoPipeline",
"CogVideoXVideoToVideoPipeline",
]
+ _import_structure["cogview3"] = [
+ "CogView3PlusPipeline",
+ ]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -469,6 +472,7 @@
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline
+ from .cogview3 import CogView3PlusPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index e3e78d0663fa..1dc1711dce0b 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -20,6 +20,7 @@
from ..configuration_utils import ConfigMixin
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
+from .cogview3 import CogView3PlusPipeline
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
@@ -118,6 +119,7 @@
("flux", FluxPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
+ ("cogview3", CogView3PlusPipeline),
]
)
diff --git a/src/diffusers/pipelines/cogview3/__init__.py b/src/diffusers/pipelines/cogview3/__init__.py
new file mode 100644
index 000000000000..50895251ba0b
--- /dev/null
+++ b/src/diffusers/pipelines/cogview3/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["CogView3PlusPipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_cogview3plus"] = ["CogView3PlusPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_cogview3plus import CogView3PlusPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
new file mode 100644
index 000000000000..3413fcda9e85
--- /dev/null
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -0,0 +1,707 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers.transformer_cogview3plus import CogView3PlusTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import CogView3PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import CogView3PlusPipeline
+
+ >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> image = pipe(prompt).images[0]
+ >>> image.save("cat.png")
+ ```
+"""
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
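+# mu grows linearly with image_seq_len, from base_shift at base_seq_len to max_shift at max_seq_len.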
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class CogView3PlusPipeline(DiffusionPipeline, PeftAdapterMixin, FromOriginalModelMixin):
+ r"""
+ The CogView3 pipeline for text-to-image generation.
+
+ Reference: https://arxiv.org/abs/2403.05121
+
+ Args:
+ transformer ([`CogView3PlusTransformer2DModel`]):
+ Conditional Transformer architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: CogView3PlusTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 64
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 256,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)[0].to(
+ dtype=dtype, device=device
+ )
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, PeftAdapterMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, PeftAdapterMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
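+ # Flatten each 2x2 spatial patch of the latent into a single token with 4x the channels;
+ # `_unpack_latents` below reverses this packing.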
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ return latents, latent_image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.cogview3.CogView3PipelineOutput`] instead
+ of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.cogview3.CogView3PipelineOutput`] or `tuple`:
+ [`~pipelines.cogview3.CogView3PipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
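+ # `mu` shifts the flow-matching sigma schedule based on the image token sequence length: longer
+ # sequences (larger images) get a larger shift. It is forwarded to the scheduler via `retrieve_timesteps`.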
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
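+ # the scheduler's timesteps live on a [0, 1000] scale; the transformer is conditioned on t / 1000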
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return CogView3PipelineOutput(images=image)
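As a sanity check on the packing logic above, the 2x2 patchification performed by `_pack_latents` and undone by `_unpack_latents` can be exercised in isolation. The sketch below is a standalone re-implementation for illustration only; the `pack`/`unpack` names and the 64x64 latent size are chosen for the example, not taken from the pipeline:

    import torch

    def pack(latents):
        # (B, C, H, W) -> (B, (H // 2) * (W // 2), 4 * C)
        b, c, h, w = latents.shape
        latents = latents.view(b, c, h // 2, 2, w // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        return latents.reshape(b, (h // 2) * (w // 2), c * 4)

    def unpack(latents, h, w):
        # inverse: (B, (H // 2) * (W // 2), 4 * C) -> (B, C, H, W)
        b, _, packed_c = latents.shape
        c = packed_c // 4
        latents = latents.view(b, h // 2, w // 2, c, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)
        return latents.reshape(b, c, h, w)

    x = torch.randn(1, 16, 64, 64)
    assert torch.equal(unpack(pack(x), 64, 64), x)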
diff --git a/src/diffusers/pipelines/cogview3/pipeline_output.py b/src/diffusers/pipelines/cogview3/pipeline_output.py
new file mode 100644
index 000000000000..3891dd51e691
--- /dev/null
+++ b/src/diffusers/pipelines/cogview3/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class CogView3PipelineOutput(BaseOutput):
+ """
+ Output class for CogView3 pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`):
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 1927fc8cd4d3..69ce91da0f25 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -422,6 +422,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class CogView3PlusPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
From 8abfe00f6382f641cedaa0abfbe90d41693d107f Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sat, 5 Oct 2024 15:23:58 +0800
Subject: [PATCH 02/30] max_shard_size="5GB" for colab running
---
scripts/convert_cogview3_to_diffusers.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
index 0af757969293..1c6ff5817051 100644
--- a/scripts/convert_cogview3_to_diffusers.py
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -6,7 +6,7 @@
Example usage:
python scripts/convert_cogview3_to_diffusers.py \
- --original_state_dict_repo_id "THUDM/cogview3" \
+ --original_state_dict_repo_id "THUDM/cogview3-sat" \
--filename "cogview3.pt" \
--transformer \
--output_path "./cogview3_diffusers" \
@@ -14,7 +14,7 @@
Alternatively, if you have a local checkpoint:
python scripts/convert_cogview3_to_diffusers.py \
- --checkpoint_path '/raid/.cache/huggingface/models--ZP2HF--CogView3-SAT/snapshots/ca86ce9ba94f9a7f2dd109e7a59e4c8ad04121be/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
+ --checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
--transformer \
--output_path "/raid/yiyi/cogview3_diffusers" \
--dtype "bf16"
@@ -26,6 +26,7 @@
--transformer: Flag to convert the transformer model.
--output_path: The path to save the converted model.
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32").
+ Default is "bf16" because CogView3 uses bfloat16 for Training.
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
"""
@@ -173,7 +174,7 @@ def main(args):
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
print(f"Saving CogView3 Transformer in Diffusers format in {args.output_path}/transformer")
- transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer", max_shard_size="5GB")
if len(original_ckpt) > 0:
print(f"Warning: {len(original_ckpt)} keys were not converted and will be saved as is: {original_ckpt.keys()}")
From d668ad91c32bfb6768ae63a3a4e3cab07ca749d2 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 7 Oct 2024 01:13:06 +0200
Subject: [PATCH 03/30] conversion script updates; modeling test; refactor
transformer
---
scripts/convert_cogview3_to_diffusers.py | 123 ++++++++++++-----
src/diffusers/models/attention_processor.py | 6 +-
.../transformers/transformer_cogview3plus.py | 124 ++++++++++--------
.../test_models_transformer_cogview3plus.py | 80 +++++++++++
4 files changed, 239 insertions(+), 94 deletions(-)
create mode 100644 tests/models/transformers/test_models_transformer_cogview3plus.py
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
index 1c6ff5817051..d1a7fd51654e 100644
--- a/scripts/convert_cogview3_to_diffusers.py
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -36,46 +36,41 @@
import torch
from accelerate import init_empty_weights
-from huggingface_hub import hf_hub_download
+from transformers import T5EncoderModel, T5Tokenizer
-from diffusers import CogView3PlusTransformer2DModel
+from diffusers import AutoencoderKL, CogView3PlusPipeline, CogView3PlusTransformer2DModel, DDIMScheduler
+from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
+TOKENIZER_MAX_LENGTH = 224
+
parser = argparse.ArgumentParser()
-parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
-parser.add_argument("--filename", default="flux.safetensors", type=str)
-parser.add_argument("--checkpoint_path", default=None, type=str)
-parser.add_argument("--transformer", action="store_true")
-parser.add_argument("--output_path", type=str)
+parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
+parser.add_argument("--vae_checkpoint_path", default=None, type=str)
+parser.add_argument("--output_path", required=True, type=str)
+parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
+parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default="bf16")
args = parser.parse_args()
-def load_original_checkpoint(args):
- if args.original_state_dict_repo_id is not None:
- ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
- elif args.checkpoint_path is not None:
- ckpt_path = args.checkpoint_path
- else:
- raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
-
- original_state_dict = torch.load(ckpt_path, map_location="cpu")
- return original_state_dict
-
-
# this is specific to `AdaLayerNormContinuous`:
-# diffusers imnplementation split the linear projection into the scale, shift while CogView3 split it tino shift, scale
+# the diffusers implementation splits the linear projection into scale, shift while CogView3 splits it into shift, scale
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
-def convert_cogview3_transformer_checkpoint_to_diffusers(original_state_dict):
+def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")
+ original_state_dict = original_state_dict["module"]
+ original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
+
new_state_dict = {}
# Convert pos_embed
@@ -150,16 +145,13 @@ def convert_cogview3_transformer_checkpoint_to_diffusers(original_state_dict):
return new_state_dict
+def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+ return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
+
+
def main(args):
- original_ckpt = load_original_checkpoint(args)
- original_ckpt = original_ckpt["module"]
- original_ckpt = {k.replace("model.diffusion_model.", ""): v for k, v in original_ckpt.items()}
-
- original_dtype = next(iter(original_ckpt.values())).dtype
- dtype = None
- if args.dtype is None:
- dtype = original_dtype
- elif args.dtype == "fp16":
+ if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
@@ -168,16 +160,75 @@ def main(args):
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
- if args.transformer:
- converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(original_ckpt)
+ transformer = None
+ vae = None
+
+ if args.transformer_checkpoint_path is not None:
+ converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(
+ args.transformer_checkpoint_path
+ )
transformer = CogView3PlusTransformer2DModel()
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+ transformer = transformer.to(dtype=dtype)
+
+ if args.vae_checkpoint_path is not None:
+ vae_config = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ("DownEncoderBlock2D",) * 4,
+ "up_block_types": ("UpDecoderBlock2D",) * 4,
+ "block_out_channels": (128, 512, 1024, 1024),
+ "layers_per_block": 3,
+ "act_fn": "silu",
+ "latent_channels": 16,
+ "norm_num_groups": 32,
+ "sample_size": 1024,
+ "scaling_factor": 0.18215,
+ "force_upcast": True,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ "mid_block_add_attention": False,
+ }
+ converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
+ vae = vae.to(dtype=dtype)
+
+ text_encoder_id = "google/t5-v1_1-xxl"
+ tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+ text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+ # Apparently, the conversion does not work anymore without this :shrug:
+ for param in text_encoder.parameters():
+ param.data = param.data.contiguous()
+
+ # TODO: figure out the correct scheduler
+ scheduler = DDIMScheduler.from_config(
+ {
+ "snr_shift_scale": 4.0, # This is different from default
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": False,
+ "num_train_timesteps": 1000,
+ "prediction_type": "v_prediction",
+ "rescale_betas_zero_snr": True,
+ "set_alpha_to_one": True,
+ "timestep_spacing": "trailing",
+ }
+ )
- print(f"Saving CogView3 Transformer in Diffusers format in {args.output_path}/transformer")
- transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer", max_shard_size="5GB")
+ pipe = CogView3PlusPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
- if len(original_ckpt) > 0:
- print(f"Warning: {len(original_ckpt)} keys were not converted and will be saved as is: {original_ckpt.keys()}")
+ # This is necessary for users with insufficient memory, such as those running on Colab or in notebooks, as it
+ # reduces the memory needed for model loading.
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
if __name__ == "__main__":
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 2ef3c6a80830..d333590982e3 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -122,7 +122,7 @@ def __init__(
out_dim: int = None,
context_pre_only=None,
pre_only=False,
- layrnorm_elementwise_affine: bool = True,
+ elementwise_affine: bool = True,
):
super().__init__()
@@ -180,8 +180,8 @@ def __init__(
self.norm_q = None
self.norm_k = None
elif qk_norm == "layer_norm":
- self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=layrnorm_elementwise_affine)
- self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=layrnorm_elementwise_affine)
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
elif qk_norm == "fp32_layer_norm":
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 8d7e37f0c925..dcabb84f8b40 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -13,13 +13,12 @@
# limitations under the License.
-from typing import Any, Dict, Union
+from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
from ...models.modeling_utils import ModelMixin
@@ -56,7 +55,7 @@ def __init__(
out_dim=dim,
bias=True,
qk_norm="layer_norm",
- layrnorm_elementwise_affine=False,
+ elementwise_affine=False,
eps=1e-6,
)
@@ -68,12 +67,12 @@ def __init__(
def forward(
self,
hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
emb: torch.Tensor,
- text_length: int,
) -> torch.Tensor:
- encoder_hidden_states, hidden_states = hidden_states[:, :text_length], hidden_states[:, text_length:]
+ text_seq_length = encoder_hidden_states.size(1)
- # norm1
+ # norm & modulate
(
norm_hidden_states,
gate_msa,
@@ -87,40 +86,56 @@ def forward(
c_gate_mlp,
) = self.norm1(hidden_states, encoder_hidden_states, emb)
- # Attention
+ # attention
attn_input = torch.cat((norm_encoder_hidden_states, norm_hidden_states), dim=1)
attn_output = self.attn(hidden_states=attn_input)
- context_attn_output, attn_output = attn_output[:, :text_length], attn_output[:, text_length:]
+ context_attn_output, attn_output = attn_output[:, :text_seq_length], attn_output[:, text_seq_length:]
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
+ # norm & modulate
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ # context norm
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+ # feed-forward
norm_hidden_states = torch.cat((norm_encoder_hidden_states, norm_hidden_states), dim=1)
-
ff_output = self.ff(norm_hidden_states)
- context_ff_output, ff_output = ff_output[:, :text_length], ff_output[:, text_length:]
-
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
- hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
- hidden_states = torch.cat((encoder_hidden_states, hidden_states), dim=1)
-
- return hidden_states
-
-
-class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
- """
- The Transformer model introduced in CogView3.
-
- Reference: https://arxiv.org/abs/2403.05121
+ encoder_hidden_states = (encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length],)
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
+
+ return hidden_states, encoder_hidden_states
+
+
+class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
+ r"""
+ The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
+ Diffusion](https://huggingface.co/papers/2403.05121).
+
+ Args:
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ attention_head_dim (`int`, defaults to `40`):
+ The number of channels in each head.
+ num_attention_heads (`int`, defaults to `64`):
+ The number of heads to use for multi-head attention.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
"""
_supports_gradient_checkpointing = True
@@ -128,33 +143,32 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
@register_to_config
def __init__(
self,
- sample_size: int = 128,
patch_size: int = 2,
in_channels: int = 16,
num_layers: int = 30,
attention_head_dim: int = 40,
num_attention_heads: int = 64,
out_channels: int = 16,
- encoder_hidden_states_dim: int = 4096,
+ text_embed_dim: int = 4096,
+ time_embed_dim: int = 512,
pooled_projection_dim: int = 1536,
pos_embed_max_size: int = 128,
- time_embed_dim: int = 512,
):
super().__init__()
self.out_channels = out_channels
- self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+ self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = CogView3PlusPatchEmbed(
- in_channels=self.config.in_channels,
+ in_channels=in_channels,
hidden_size=self.inner_dim,
- patch_size=self.config.patch_size,
- text_hidden_size=self.config.encoder_hidden_states_dim,
- pos_embed_max_size=self.config.pos_embed_max_size,
+ patch_size=patch_size,
+ text_hidden_size=text_embed_dim,
+ pos_embed_max_size=pos_embed_max_size,
)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
- embedding_dim=self.config.time_embed_dim,
- pooled_projection_dim=self.config.pooled_projection_dim,
+ embedding_dim=time_embed_dim,
+ pooled_projection_dim=pooled_projection_dim,
timesteps_dim=self.inner_dim,
)
@@ -162,17 +176,17 @@ def __init__(
[
CogView3PlusTransformerBlock(
dim=self.inner_dim,
- num_attention_heads=self.config.num_attention_heads,
- attention_head_dim=self.config.attention_head_dim,
- time_embed_dim=self.config.time_embed_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
)
- for _ in range(self.config.num_layers)
+ for _ in range(num_layers)
]
)
self.norm_out = AdaLayerNormContinuous(
embedding_dim=self.inner_dim,
- conditioning_embedding_dim=self.config.time_embed_dim,
+ conditioning_embedding_dim=time_embed_dim,
elementwise_affine=False,
eps=1e-6,
)
@@ -286,17 +300,17 @@ def _set_gradient_checkpointing(self, module, value=False):
def forward(
self,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- pooled_projections: torch.FloatTensor = None,
- timestep: torch.LongTensor = None,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ pooled_projections: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
return_dict: bool = True,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`CogView3PlusTransformer2DModel`] forward method.
Args:
- hidden_states (`torch.FloatTensor`): Input `hidden_states`.
+ hidden_states (`torch.Tensor`): Input `hidden_states`.
timestep (`torch.LongTensor`): Indicates denoising step.
y (`torch.LongTensor`, *optional*): Label input used to obtain the label embeddings.
block_controlnet_hidden_states: (`list` of `torch.Tensor`): A list of tensors for residuals.
@@ -308,43 +322,43 @@ def forward(
"""
height, width = hidden_states.shape[-2:]
- text_length = encoder_hidden_states.shape[1]
+ text_seq_length = encoder_hidden_states.shape[1]
hidden_states = self.pos_embed(
hidden_states, encoder_hidden_states
) # takes care of adding positional embeddings too.
emb = self.time_text_embed(timestep, pooled_projections)
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
- def create_custom_forward(module, return_dict=None):
+ def create_custom_forward(module):
def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
+ return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
+ encoder_hidden_states,
emb,
- text_length,
**ckpt_kwargs,
)
else:
- hidden_states = block(
+ hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
emb=emb,
- text_length=text_length,
)
- hidden_states = hidden_states[:, text_length:]
hidden_states = self.norm_out(hidden_states, emb)
hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels)
+
# unpatchify
patch_size = self.config.patch_size
height = height // patch_size
diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py
new file mode 100644
index 000000000000..a82417b2f669
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_cogview3plus.py
@@ -0,0 +1,80 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import CogView3PlusTransformer2DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = CogView3PlusTransformer2DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ height = 8
+ width = 8
+ embedding_dim = 8
+ sequence_length = 8
+
+ hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ pooled_projections = torch.randn((batch_size, embedding_dim)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_projections,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (1, 4, 8, 8)
+
+ @property
+ def output_shape(self):
+ return (1, 4, 8, 8)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 2,
+ "in_channels": 4,
+ "num_layers": 1,
+ "attention_head_dim": 4,
+ "num_attention_heads": 2,
+ "out_channels": 4,
+ "text_embed_dim": 8,
+ "time_embed_dim": 8,
+ "pooled_projection_dim": 8,
+ "pos_embed_max_size": 8,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
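Outside of the test harness, the same tiny configuration can be used to sanity-check output shapes by hand. A rough sketch mirroring the dummy inputs above (it assumes the standard `Transformer2DModelOutput.sample` field; note that later patches in this series change the forward signature again):

    import torch
    from diffusers import CogView3PlusTransformer2DModel

    model = CogView3PlusTransformer2DModel(
        patch_size=2, in_channels=4, num_layers=1, attention_head_dim=4, num_attention_heads=2,
        out_channels=4, text_embed_dim=8, time_embed_dim=8, pooled_projection_dim=8, pos_embed_max_size=8,
    )
    out = model(
        hidden_states=torch.randn(2, 4, 8, 8),
        encoder_hidden_states=torch.randn(2, 8, 8),
        pooled_projections=torch.randn(2, 8),
        timestep=torch.randint(0, 1000, (2,)),
    ).sample
    print(out.shape)  # expected to match the input latents: (2, 4, 8, 8)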
From 53935c049e855ce15b784a28fd642275345b128b Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 7 Oct 2024 01:16:46 +0200
Subject: [PATCH 04/30] make fix-copies
---
src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++
.../dummy_torch_and_transformers_objects.py | 16 ++++++++--------
2 files changed, 23 insertions(+), 8 deletions(-)
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 1ab946ce7257..eaab67c93b18 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -122,6 +122,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class CogView3PlusTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ConsistencyDecoderVAE(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 69ce91da0f25..0a8d3c872381 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -317,7 +317,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CycleDiffusionPipeline(metaclass=DummyObject):
+class CogView3PlusPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -332,7 +332,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
+class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -347,7 +347,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlNetInpaintPipeline(metaclass=DummyObject):
+class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -362,7 +362,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlNetPipeline(metaclass=DummyObject):
+class FluxControlNetInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -377,7 +377,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxImg2ImgPipeline(metaclass=DummyObject):
+class FluxControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -392,7 +392,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxInpaintPipeline(metaclass=DummyObject):
+class FluxImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -407,7 +407,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxPipeline(metaclass=DummyObject):
+class FluxInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -422,7 +422,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogView3PlusPipeline(metaclass=DummyObject):
+class FluxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
From 56b5599baa24f8de9a014a314c5bfbc70214fd86 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Mon, 7 Oct 2024 23:30:24 +0800
Subject: [PATCH 05/30] Update convert_cogview3_to_diffusers.py
---
scripts/convert_cogview3_to_diffusers.py | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
index d1a7fd51654e..8191cc1feaad 100644
--- a/scripts/convert_cogview3_to_diffusers.py
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -38,11 +38,10 @@
from accelerate import init_empty_weights
from transformers import T5EncoderModel, T5Tokenizer
-from diffusers import AutoencoderKL, CogView3PlusPipeline, CogView3PlusTransformer2DModel, DDIMScheduler
+from diffusers import AutoencoderKL, CogView3PlusPipeline, CogView3PlusTransformer2DModel, CogVideoXDDIMScheduler
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available
-
CTX = init_empty_weights if is_accelerate_available else nullcontext
TOKENIZER_MAX_LENGTH = 224
@@ -202,10 +201,10 @@ def main(args):
for param in text_encoder.parameters():
param.data = param.data.contiguous()
- # TODO: figure out the correct scheduler
- scheduler = DDIMScheduler.from_config(
+ # TODO: figure out whether the correct scheduler is the same as CogVideoXDDIMScheduler
+ scheduler = CogVideoXDDIMScheduler.from_config(
{
- "snr_shift_scale": 4.0, # This is different from default
+ "snr_shift_scale": 4.0,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
From 8e2ddd56faea82e28c99065c368f73421ef16f11 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 15:59:39 +0200
Subject: [PATCH 06/30] initial pipeline draft
---
scripts/convert_cogview3_to_diffusers.py | 27 +-
src/diffusers/models/embeddings.py | 45 +-
.../transformers/transformer_cogview3plus.py | 22 +-
.../cogview3/pipeline_cogview3plus.py | 733 +++++++++---------
4 files changed, 434 insertions(+), 393 deletions(-)
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
index 8191cc1feaad..c70b158e6d32 100644
--- a/scripts/convert_cogview3_to_diffusers.py
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -78,19 +78,19 @@ def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
new_state_dict["pos_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
new_state_dict["pos_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
- # Convert time_text_embed
- new_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+ # Convert time_condition_embed
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_embed.0.weight"
)
- new_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_embed.0.bias")
- new_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_embed.0.bias")
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_embed.2.weight"
)
- new_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_embed.2.bias")
- new_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop("label_emb.0.0.weight")
- new_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop("label_emb.0.0.bias")
- new_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop("label_emb.0.2.weight")
- new_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop("label_emb.0.2.bias")
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_embed.2.bias")
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop("label_emb.0.0.weight")
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop("label_emb.0.0.bias")
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop("label_emb.0.2.weight")
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop("label_emb.0.2.bias")
# Convert transformer blocks
for i in range(30):
@@ -150,6 +150,8 @@ def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
def main(args):
+ if args.dtype is None:
+ dtype = None
if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
@@ -168,7 +170,9 @@ def main(args):
)
transformer = CogView3PlusTransformer2DModel()
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
- transformer = transformer.to(dtype=dtype)
+ if dtype is not None:
+ # Original checkpoint data type will be preserved
+ transformer = transformer.to(dtype=dtype)
if args.vae_checkpoint_path is not None:
vae_config = {
@@ -191,7 +195,8 @@ def main(args):
converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_state_dict, strict=True)
- vae = vae.to(dtype=dtype)
+ if dtype is not None:
+ vae = vae.to(dtype=dtype)
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index cddf46cfea77..3cf22dfcfc4f 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1070,27 +1070,6 @@ def forward(self, image_embeds: torch.Tensor):
return self.norm(x)
-class CogView3CombineTimestepLabelEmbedding(nn.Module):
- def __init__(self, time_embed_dim, label_embed_dim, in_channels=2560):
- super().__init__()
-
- self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=1)
- self.timestep_embedder = TimestepEmbedding(in_channels=in_channels, time_embed_dim=time_embed_dim)
- self.label_embedder = nn.Sequential(
- nn.Linear(label_embed_dim, time_embed_dim),
- nn.SiLU(),
- nn.Linear(time_embed_dim, time_embed_dim),
- )
-
- def forward(self, timestep, class_labels, hidden_dtype=None):
- t_proj = self.time_proj(timestep)
- t_emb = self.timestep_embedder(t_proj.to(dtype=hidden_dtype))
- label_emb = self.label_embedder(class_labels)
- emb = t_emb + label_emb
-
- return emb
-
-
class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
@@ -1153,6 +1132,30 @@ def forward(self, timestep, guidance, pooled_projection):
return conditioning
+class CogView3CombinedTimestepConditionEmbeddings(nn.Module):
+ def __init__(self, timestep_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim=256):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=timestep_dim)
+ self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, timestep_dim, act_fn="silu")
+
+ def forward(self, timestep: torch.Tensor, original_size: torch.Tensor, target_size: torch.Tensor, crop_coords: torch.Tensor, hidden_dtype: torch.dtype) -> torch.Tensor:
+ timesteps_proj = self.time_proj(timestep)
+
+ original_size_proj = self.condition_proj(original_size)
+ crop_coords_proj = self.condition_proj(crop_coords)
+ target_size_proj = self.condition_proj(target_size)
+ condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1) # (B, 3 * condition_dim)
+
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(hidden_dtype)) # (B, embedding_dim)
+ condition_emb = self.condition_embedder(condition_proj.to(hidden_dtype)) # (B, embedding_dim)
+
+ conditioning = timesteps_emb + condition_emb
+ return conditioning
+
+
class HunyuanDiTAttentionPool(nn.Module):
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index dcabb84f8b40..8fedd7c4e9c9 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -24,7 +24,7 @@
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
-from ..embeddings import CogView3PlusPatchEmbed, CombinedTimestepTextProjEmbeddings
+from ..embeddings import CogView3PlusPatchEmbed, CogView3CombinedTimestepConditionEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
@@ -108,7 +108,7 @@ def forward(
norm_hidden_states = torch.cat((norm_encoder_hidden_states, norm_hidden_states), dim=1)
ff_output = self.ff(norm_hidden_states)
- encoder_hidden_states = (encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length],)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
return hidden_states, encoder_hidden_states
@@ -151,8 +151,10 @@ def __init__(
out_channels: int = 16,
text_embed_dim: int = 4096,
time_embed_dim: int = 512,
+ condition_dim: int = 512,
pooled_projection_dim: int = 1536,
pos_embed_max_size: int = 128,
+ sample_size: int = 128,
):
super().__init__()
self.out_channels = out_channels
@@ -166,8 +168,9 @@ def __init__(
pos_embed_max_size=pos_embed_max_size,
)
- self.time_text_embed = CombinedTimestepTextProjEmbeddings(
- embedding_dim=time_embed_dim,
+ self.time_condition_embed = CogView3CombinedTimestepConditionEmbeddings(
+ timestep_dim=time_embed_dim,
+ condition_dim=condition_dim,
pooled_projection_dim=pooled_projection_dim,
timesteps_dim=self.inner_dim,
)
@@ -301,9 +304,11 @@ def _set_gradient_checkpointing(self, module, value=False):
def forward(
self,
hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- pooled_projections: Optional[torch.Tensor] = None,
- timestep: Optional[torch.LongTensor] = None,
+ encoder_hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ original_size: torch.Tensor,
+ target_size: torch.Tensor,
+ crop_coords: torch.Tensor,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
@@ -320,14 +325,13 @@ def forward(
Returns:
Output tensor or `Transformer2DModelOutput`.
"""
-
height, width = hidden_states.shape[-2:]
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = self.pos_embed(
hidden_states, encoder_hidden_states
) # takes care of adding positional embeddings too.
- emb = self.time_text_embed(timestep, pooled_projections)
+ emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 3413fcda9e85..c83b4a97951f 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -1,4 +1,5 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,68 +14,66 @@
# limitations under the License.
import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-import numpy as np
import torch
-from transformers import T5EncoderModel, T5TokenizerFast
+from transformers import T5EncoderModel, T5Tokenizer
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.autoencoders import AutoencoderKL
-from ...models.transformers.transformer_cogview3plus import CogView3PlusTransformer2DModel
-from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import (
- USE_PEFT_BACKEND,
- is_torch_xla_available,
- logging,
- replace_example_docstring,
- scale_lora_layers,
- unscale_lora_layers,
-)
+from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
+from ...models.embeddings import get_3d_rotary_pos_embed
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CogView3PipelineOutput
-if is_torch_xla_available():
- import torch_xla.core.xla_model as xm
-
- XLA_AVAILABLE = True
-else:
- XLA_AVAILABLE = False
-
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
- ```py
+ ```python
>>> import torch
- >>> from diffusers import CogView3PlusPipeline
-
- >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16)
- >>> pipe.to("cuda")
- >>> prompt = "A cat holding a sign that says hello world"
- >>> # Depending on the variant being used, the pipeline call will slightly vary.
- >>> # Refer to the pipeline documentation for more details.
- >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
- >>> image.save("cat.png")
+ >>> from diffusers import CogVideoXPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
+ >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
+ >>> prompt = (
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+ ... "atmosphere of this unique musical performance."
+ ... )
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=8)
```
"""
-def calculate_shift(
- image_seq_len,
- base_seq_len: int = 256,
- max_seq_len: int = 4096,
- base_shift: float = 0.5,
- max_shift: float = 1.16,
-):
- m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
- b = base_shift - m * base_seq_len
- mu = image_seq_len * m + b
- return mu
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
@@ -137,62 +136,62 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class CogView3PlusPipeline(DiffusionPipeline, PeftAdapterMixin, FromOriginalModelMixin):
+class CogView3PlusPipeline(DiffusionPipeline):
r"""
- The CogView3 pipeline for text-to-image generation.
+ Pipeline for text-to-image generation using CogView3Plus.
- Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
- transformer ([`CogView3PlusTransformerBlock`]):
- Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
- scheduler ([`FlowMatchEulerDiscreteScheduler`]):
- A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`T5EncoderModel`]):
- [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
- the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
- tokenizer (`T5TokenizerFast`):
- Second Tokenizer of class
- [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ Frozen text-encoder. CogView3Plus uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CogView3PlusTransformer2DModel`]):
+ A text conditioned `CogView3PlusTransformer2DModel` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
- model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
- _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
def __init__(
self,
- scheduler: FlowMatchEulerDiscreteScheduler,
- vae: AutoencoderKL,
+ tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
- tokenizer: T5TokenizerFast,
+ vae: AutoencoderKL,
transformer: CogView3PlusTransformer2DModel,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
super().__init__()
self.register_modules(
- vae=vae,
- text_encoder=text_encoder,
- tokenizer=tokenizer,
- transformer=transformer,
- scheduler=scheduler,
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.tokenizer_max_length = (
- self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
- )
- self.default_sample_size = 64
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
- num_images_per_prompt: int = 1,
- max_sequence_length: int = 256,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
@@ -207,102 +206,161 @@ def _get_t5_prompt_embeds(
padding="max_length",
max_length=max_sequence_length,
truncation=True,
- return_length=False,
- return_overflowing_tokens=False,
+ add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
- prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)[0].to(
- dtype=dtype, device=device
- )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
-
- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
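The `repeat`/`view` pair above duplicates each prompt's embedding `num_videos_per_prompt` times while staying friendly to the mps backend. A small shape sketch with made-up dimensions, showing it matches `repeat_interleave` along the batch axis:

```python
import torch

batch_size, seq_len, hidden = 2, 4, 8  # hypothetical sizes
num_per_prompt = 3

prompt_embeds = torch.randn(batch_size, seq_len, hidden)

# Repeat along the sequence axis, then fold the copies back into the batch axis.
duplicated = prompt_embeds.repeat(1, num_per_prompt, 1)
duplicated = duplicated.view(batch_size * num_per_prompt, seq_len, hidden)

# Equivalent to repeat_interleave over the batch dimension.
reference = prompt_embeds.repeat_interleave(num_per_prompt, dim=0)
assert torch.equal(duplicated, reference)
print(duplicated.shape)  # torch.Size([6, 4, 8])
```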
def encode_prompt(
self,
prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
device: Optional[torch.device] = None,
- num_images_per_prompt: int = 1,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- max_sequence_length: int = 512,
- lora_scale: Optional[float] = None,
+ dtype: Optional[torch.dtype] = None,
):
r"""
+ Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
- device: (`torch.device`):
- torch device
- num_images_per_prompt (`int`):
- number of images that should be generated per prompt
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
- lora_scale (`float`, *optional*):
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
"""
device = device or self._execution_device
- # set lora scale so that monkey patched LoRA
- # function of text encoder can correctly access it
- if lora_scale is not None and isinstance(self, PeftAdapterMixin):
- self._lora_scale = lora_scale
-
- # dynamically adjust the LoRA scale
- if self.text_encoder is not None and USE_PEFT_BACKEND:
- scale_lora_layers(self.text_encoder, lora_scale)
- if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
- scale_lora_layers(self.text_encoder_2, lora_scale)
-
prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
- num_images_per_prompt=num_images_per_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
+ dtype=dtype,
)
- if self.text_encoder is not None:
- if isinstance(self, PeftAdapterMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
- dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
- text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
- return prompt_embeds, pooled_prompt_embeds, text_ids
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
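`prepare_latents` draws noise at the latent resolution and rescales it by the scheduler's `init_noise_sigma`. A rough sketch of the shape arithmetic under assumed numbers (a VAE with four `block_out_channels`, hence `vae_scale_factor = 2 ** (4 - 1) = 8`, and 16 latent channels):

```python
import torch

vae_scale_factor = 8  # assumed: 2 ** (len(block_out_channels) - 1) with 4 blocks
batch_size, latent_channels, height, width = 1, 16, 1024, 1024

shape = (batch_size, latent_channels, height // vae_scale_factor, width // vae_scale_factor)
generator = torch.Generator().manual_seed(0)
latents = torch.randn(shape, generator=generator)

# Schedulers expose init_noise_sigma (often 1.0); scaling puts the starting noise
# at the magnitude the scheduler expects.
init_noise_sigma = 1.0
latents = latents * init_noise_sigma
print(latents.shape)  # torch.Size([1, 16, 128, 128])
```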
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
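The signature inspection above forwards `eta` and `generator` only to schedulers whose `step()` accepts them. A minimal standalone sketch of the same pattern against two toy scheduler classes; the class names are illustrative, not diffusers API:

```python
import inspect


class ToyDDIMScheduler:
    def step(self, model_output, timestep, sample, eta=0.0, generator=None):
        return sample


class ToyEulerScheduler:
    def step(self, model_output, timestep, sample):
        return sample


def prepare_extra_step_kwargs(scheduler, generator, eta):
    # Only pass kwargs that the scheduler's step() signature actually declares.
    params = set(inspect.signature(scheduler.step).parameters.keys())
    extra = {}
    if "eta" in params:
        extra["eta"] = eta
    if "generator" in params:
        extra["generator"] = generator
    return extra


print(prepare_extra_step_kwargs(ToyDDIMScheduler(), None, 0.0))   # {'eta': 0.0, 'generator': None}
print(prepare_extra_step_kwargs(ToyEulerScheduler(), None, 0.0))  # {}
```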
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
prompt_embeds=None,
- pooled_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
- max_sequence_length=None,
+ negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -313,128 +371,55 @@ def check_inputs(
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
- if prompt_embeds is not None and pooled_prompt_embeds is None:
+
+ if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
- "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
- if max_sequence_length is not None and max_sequence_length > 512:
- raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
-
- @staticmethod
- def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
-
- latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
- latent_image_ids = latent_image_ids.reshape(
- latent_image_id_height * latent_image_id_width, latent_image_id_channels
- )
-
- return latent_image_ids.to(device=device, dtype=dtype)
-
- @staticmethod
- def _pack_latents(latents, batch_size, num_channels_latents, height, width):
- latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
- latents = latents.permute(0, 2, 4, 1, 3, 5)
- latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
- return latents
-
- @staticmethod
- def _unpack_latents(latents, height, width, vae_scale_factor):
- batch_size, num_patches, channels = latents.shape
-
- height = height // vae_scale_factor
- width = width // vae_scale_factor
-
- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
- latents = latents.permute(0, 3, 1, 4, 2, 5)
-
- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
-
- return latents
-
- def enable_vae_slicing(self):
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.vae.enable_slicing()
-
- def disable_vae_slicing(self):
- r"""
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
- computing decoding in one step.
- """
- self.vae.disable_slicing()
-
- def enable_vae_tiling(self):
- r"""
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
- processing larger images.
- """
- self.vae.enable_tiling()
-
- def disable_vae_tiling(self):
- r"""
- Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
- computing decoding in one step.
- """
- self.vae.disable_tiling()
-
- def prepare_latents(
- self,
- batch_size,
- num_channels_latents,
- height,
- width,
- dtype,
- device,
- generator,
- latents=None,
- ):
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
-
- shape = (batch_size, num_channels_latents, height, width)
-
- if latents is not None:
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
- return latents.to(device=device, dtype=dtype), latent_image_ids
-
- if isinstance(generator, list) and len(generator) != batch_size:
+ if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
-
- return latents, latent_image_ids
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def fuse_qkv_projections(self) -> None:
+ r"""Enables fused QKV projections."""
+ self.fusing_transformer = True
+ self.transformer.fuse_qkv_projections()
+
+ def unfuse_qkv_projections(self) -> None:
+ r"""Disable QKV projection fusion if enabled."""
+ if not self.fusing_transformer:
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.transformer.unfuse_qkv_projections()
+ self.fusing_transformer = False
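A brief usage sketch for the fusion helpers above, assuming `pipe` is an already-loaded `CogView3PlusPipeline`:

```python
# Hypothetical usage. Fusing merges each attention layer's to_q/to_k/to_v projections
# into a single matmul, which can reduce per-layer overhead; unfusing restores the
# original layout.
pipe.fuse_qkv_projections()
image = pipe("a lighthouse at dusk").images[0]
pipe.unfuse_qkv_projections()
```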
@property
def guidance_scale(self):
return self._guidance_scale
- @property
- def joint_attention_kwargs(self):
- return self._joint_attention_kwargs
-
@property
def num_timesteps(self):
return self._num_timesteps
@@ -447,42 +432,49 @@ def interrupt(self):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
- prompt: Union[str, List[str]] = None,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
- num_inference_steps: int = 28,
- timesteps: List[int] = None,
- guidance_scale: float = 3.5,
- num_images_per_prompt: Optional[int] = 1,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_images_per_prompt: int = 1,
+ eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- output_type: Optional[str] = "pil",
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ output_type: str = "pil",
return_dict: bool = True,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- max_sequence_length: int = 512,
- ):
- r"""
+ max_sequence_length: int = 224,
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
+ """
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
- instead.
- prompt_2 (`str` or `List[str]`, *optional*):
- The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
- will be used instead
- prompt_3 (`str` or `List[str]`, *optional*):
- The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
- will be used instead
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
- num_inference_steps (`int`, *optional*, defaults to 50):
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. If not provided, it is set to 1024.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. If not provided, it is set to 1024.
+ num_inference_steps (`int`, *optional*, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
@@ -495,17 +487,7 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
- negative_prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
- less than `1`).
- negative_prompt_2 (`str` or `List[str]`, *optional*):
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
- `text_encoder_2`. If not defined, `negative_prompt` is used instead
- negative_prompt_3 (`str` or `List[str]`, *optional*):
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
- `text_encoder_3`. If not defined, `negative_prompt` is used instead
- num_images_per_prompt (`int`, *optional*, defaults to 1):
+ num_images_per_prompt (`int`, *optional*, defaults to `1`):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
@@ -521,20 +503,42 @@ def __call__(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
- input argument.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size`, the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified, it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be the same
+ as the `target_size` in most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
- joint_attention_kwargs (`dict`, *optional*):
+ attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -547,35 +551,41 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
- max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+ max_sequence_length (`int`, defaults to `224`):
+ Maximum sequence length in encoded prompt. Must be consistent with
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
Examples:
Returns:
- [`~pipelines.cogview3.CogView3PlusPipelineOutput`] or `tuple`:
- [`~pipelines.cogview3.CogView3PlusPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
- returning a tuple, the first element is a list with the generated images.
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
- height = height or self.default_sample_size * self.vae_scale_factor
- width = width or self.default_sample_size * self.vae_scale_factor
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
- max_sequence_length=max_sequence_length,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
)
-
self._guidance_scale = guidance_scale
- self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
- # 2. Define call parameters
+ # 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -585,28 +595,34 @@ def __call__(
device = self._execution_device
- lora_scale = (
- self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
- )
- (
- prompt_embeds,
- pooled_prompt_embeds,
- text_ids,
- ) = self.encode_prompt(
- prompt=prompt,
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- device=device,
- num_images_per_prompt=num_images_per_prompt,
+ negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
- lora_scale=lora_scale,
+ device=device,
)
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
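With classifier-free guidance enabled, the negative and positive embeddings are stacked along the batch axis so both branches run in a single forward pass. A shape sketch with made-up sizes:

```python
import torch

# Hypothetical shapes: batch of 2 prompts, 224 tokens, hidden size 4096
negative_prompt_embeds = torch.randn(2, 224, 4096)
prompt_embeds = torch.randn(2, 224, 4096)

# Unconditional embeddings first, conditional second -- the denoising loop relies on
# this ordering when it later splits the noise prediction with .chunk(2).
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
print(prompt_embeds.shape)  # torch.Size([4, 224, 4096])
```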
- # 4. Prepare latent variables
- num_channels_latents = self.transformer.config.in_channels // 4
- latents, latent_image_ids = self.prepare_latents(
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
batch_size * num_images_per_prompt,
- num_channels_latents,
+ latent_channels,
height,
width,
prompt_embeds.dtype,
@@ -615,64 +631,82 @@ def __call__(
latents,
)
- # 5. Prepare timesteps
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- image_seq_len = latents.shape[1]
- mu = calculate_shift(
- image_seq_len,
- self.scheduler.config.base_image_seq_len,
- self.scheduler.config.max_image_seq_len,
- self.scheduler.config.base_shift,
- self.scheduler.config.max_shift,
- )
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler,
- num_inference_steps,
- device,
- timesteps,
- sigmas,
- mu=mu,
- )
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
- self._num_timesteps = len(timesteps)
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare additional timestep conditions
+ # TODO: Make this like SDXL
+ original_size = torch.tensor(original_size, dtype=prompt_embeds.dtype)
+ target_size = torch.tensor(target_size, dtype=prompt_embeds.dtype)
+ crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=prompt_embeds.dtype)
- # handle guidance
- if self.transformer.config.guidance_embeds:
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
- guidance = guidance.expand(latents.shape[0])
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_original_size = torch.tensor(negative_original_size, dtype=prompt_embeds.dtype)
+ negative_target_size = torch.tensor(negative_target_size, dtype=prompt_embeds.dtype)
+ negative_crops_coords_top_left = torch.tensor(negative_crops_coords_top_left, dtype=prompt_embeds.dtype)
else:
- guidance = None
+ negative_original_size = original_size
+ negative_target_size = target_size
+ negative_crops_coords_top_left = crops_coords_top_left
+
+ if do_classifier_free_guidance:
+ original_size = torch.cat([negative_original_size, original_size])
+ target_size = torch.cat([negative_target_size, target_size])
+ crops_coords_top_left = torch.cat([negative_crops_coords_top_left, crops_coords_top_left])
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
- # 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ timestep = t.expand(latent_model_input.shape[0])
+ # predict noise model_output
noise_pred = self.transformer(
- hidden_states=latents,
- timestep=timestep / 1000,
- guidance=guidance,
- pooled_projections=pooled_prompt_embeds,
+ hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
- txt_ids=text_ids,
- img_ids=latent_image_ids,
- joint_attention_kwargs=self.joint_attention_kwargs,
+ timestep=timestep,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
return_dict=False,
)[0]
+ noise_pred = noise_pred.float()
- # compute the previous noisy sample x_t -> x_t-1
- latents_dtype = latents.dtype
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
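`use_dynamic_cfg` rescales the guidance weight over the course of sampling with a cosine ramp. A small sketch of the schedule using the same formula as the hunk above; for illustration `t` is treated as a step index counting down to 0, whereas the pipeline passes the scheduler timestep (`t.item()`):

```python
import math


def dynamic_guidance(guidance_scale, t, num_inference_steps):
    # Cosine ramp: guidance stays close to 1 for most of the schedule and
    # approaches 1 + guidance_scale near the end of sampling.
    return 1 + guidance_scale * (
        (1 - math.cos(math.pi * ((num_inference_steps - t) / num_inference_steps) ** 5.0)) / 2
    )


for t in (50, 40, 25, 10, 0):  # illustrative step indices for a 50-step run
    print(t, round(dynamic_guidance(6.0, t, 50), 3))
```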
- if latents.dtype != latents_dtype:
- if torch.backends.mps.is_available():
- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
- latents = latents.to(latents_dtype)
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(prompt_embeds.dtype)
+ # call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
@@ -681,22 +715,17 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
- if XLA_AVAILABLE:
- xm.mark_step()
-
- if output_type == "latent":
- image = latents
-
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
else:
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
- image = self.vae.decode(latents, return_dict=False)[0]
- image = self.image_processor.postprocess(image, output_type=output_type)
+ image = latents
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
From 0e33401e470ac3a943298974bfe0a9dbdc3d296d Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 16:00:19 +0200
Subject: [PATCH 07/30] make style
---
scripts/convert_cogview3_to_diffusers.py | 27 ++++++++++++++-----
src/diffusers/models/embeddings.py | 17 +++++++++---
.../transformers/transformer_cogview3plus.py | 4 +--
.../cogview3/pipeline_cogview3plus.py | 19 ++++++-------
4 files changed, 45 insertions(+), 22 deletions(-)
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
index c70b158e6d32..389bfc854737 100644
--- a/scripts/convert_cogview3_to_diffusers.py
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -38,10 +38,11 @@
from accelerate import init_empty_weights
from transformers import T5EncoderModel, T5Tokenizer
-from diffusers import AutoencoderKL, CogView3PlusPipeline, CogView3PlusTransformer2DModel, CogVideoXDDIMScheduler
+from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available
+
CTX = init_empty_weights if is_accelerate_available else nullcontext
TOKENIZER_MAX_LENGTH = 224
@@ -82,15 +83,27 @@ def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_embed.0.weight"
)
- new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_embed.0.bias")
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+ "time_embed.0.bias"
+ )
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_embed.2.weight"
)
- new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_embed.2.bias")
- new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop("label_emb.0.0.weight")
- new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop("label_emb.0.0.bias")
- new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop("label_emb.0.2.weight")
- new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop("label_emb.0.2.bias")
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+ "time_embed.2.bias"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
+ "label_emb.0.0.weight"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
+ "label_emb.0.0.bias"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
+ "label_emb.0.2.weight"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
+ "label_emb.0.2.bias"
+ )
# Convert transformer blocks
for i in range(30):
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 3cf22dfcfc4f..4f343ddf71e5 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1141,14 +1141,23 @@ def __init__(self, timestep_dim: int, condition_dim: int, pooled_projection_dim:
self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=timestep_dim)
self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, timestep_dim, act_fn="silu")
- def forward(self, timestep: torch.Tensor, original_size: torch.Tensor, target_size: torch.Tensor, crop_coords: torch.Tensor, hidden_dtype: torch.dtype) -> torch.Tensor:
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ original_size: torch.Tensor,
+ target_size: torch.Tensor,
+ crop_coords: torch.Tensor,
+ hidden_dtype: torch.dtype,
+ ) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
-
+
original_size_proj = self.condition_proj(original_size)
crop_coords_proj = self.condition_proj(crop_coords)
target_size_proj = self.condition_proj(target_size)
- condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1) # (B, 3 * condition_dim)
-
+ condition_proj = torch.cat(
+ [original_size_proj, crop_coords_proj, target_size_proj], dim=1
+ ) # (B, 3 * condition_dim)
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(hidden_dtype)) # (B, embedding_dim)
condition_emb = self.condition_embedder(condition_proj.to(hidden_dtype)) # (B, embedding_dim)
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 8fedd7c4e9c9..8c9ef92adfda 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Union
import torch
import torch.nn as nn
@@ -24,7 +24,7 @@
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
-from ..embeddings import CogView3PlusPatchEmbed, CogView3CombinedTimestepConditionEmbeddings
+from ..embeddings import CogView3CombinedTimestepConditionEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index c83b4a97951f..31f1578526d6 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -15,7 +15,7 @@
import inspect
import math
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
@@ -23,7 +23,6 @@
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
-from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import logging, replace_example_docstring
@@ -459,7 +458,7 @@ def __call__(
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 224,
- ) -> Union[CogVideoXPipelineOutput, Tuple]:
+ ) -> Union[CogView3PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -558,17 +557,17 @@ def __call__(
Examples:
Returns:
- [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
- [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogView3PipelineOutput`] or `tuple`:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
-
+
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
-
+
original_size = original_size or (height, width)
target_size = target_size or (height, width)
@@ -721,10 +720,12 @@ def __call__(
progress_bar.update()
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
else:
image = latents
-
+
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
From 3873c57b58840190ce421647ce390710e24051c7 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 16:54:47 +0200
Subject: [PATCH 08/30] fight bugs 🐛🪳
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/diffusers/models/embeddings.py | 11 +++--
.../transformers/transformer_cogview3plus.py | 2 +-
.../cogview3/pipeline_cogview3plus.py | 47 ++++++++++---------
3 files changed, 32 insertions(+), 28 deletions(-)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 4f343ddf71e5..dc3599a1ec3f 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1151,15 +1151,16 @@ def forward(
) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
- original_size_proj = self.condition_proj(original_size)
- crop_coords_proj = self.condition_proj(crop_coords)
- target_size_proj = self.condition_proj(target_size)
+ original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
+ crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
+ target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
condition_proj = torch.cat(
[original_size_proj, crop_coords_proj, target_size_proj], dim=1
) # (B, 3 * condition_dim)
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(hidden_dtype)) # (B, embedding_dim)
- condition_emb = self.condition_embedder(condition_proj.to(hidden_dtype)) # (B, embedding_dim)
+
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
+ condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
conditioning = timesteps_emb + condition_emb
return conditioning
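The flatten/view fix above lets `condition_proj` (a sinusoidal `Timesteps`-style projection that expects a 1-D input) embed each component of the 2-element size and crop tuples separately. A shape walkthrough under assumed dimensions, using a rough stand-in for the projection rather than the actual module:

```python
import math

import torch

batch, condition_dim = 2, 256  # condition_dim matches the config change later in this patch

# Hypothetical (height, width) pairs, shape (B, 2); crop_coords and target_size look the same.
original_size = torch.tensor([[1024.0, 1024.0], [768.0, 1344.0]])


def sinusoidal_proj(x_1d, dim=condition_dim):
    # Rough stand-in for a sinusoidal Timesteps projection: maps N scalars to (N, dim).
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = x_1d[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


flat = original_size.flatten()        # (2 * B,)            -> (4,)
proj = sinusoidal_proj(flat)          # (2 * B, 256)        -> (4, 256)
per_sample = proj.view(batch, -1)     # (B, 2 * 256)        -> (2, 512)
print(per_sample.shape)               # torch.Size([2, 512])

# Concatenating the three projected tuples gives (B, 3 * 2 * condition_dim) = (B, 1536),
# which lines up with pooled_projection_dim for the condition_embedder.
```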
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 8c9ef92adfda..d6a023f5dde3 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -151,7 +151,7 @@ def __init__(
out_channels: int = 16,
text_embed_dim: int = 4096,
time_embed_dim: int = 512,
- condition_dim: int = 512,
+ condition_dim: int = 256,
pooled_projection_dim: int = 1536,
pos_embed_max_size: int = 128,
sample_size: int = 128,
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 31f1578526d6..8608370789ce 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -144,7 +144,7 @@ class CogView3PlusPipeline(DiffusionPipeline):
Args:
vae ([`AutoencoderKL`]):
- Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. CogView3Plus uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
@@ -153,9 +153,9 @@ class CogView3PlusPipeline(DiffusionPipeline):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CogView3PlusTransformer2DModel`]):
- A text conditioned `CogView3PlusTransformer2DModel` to denoise the encoded video latents.
+ A text conditioned `CogView3PlusTransformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
- A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
_optional_components = []
@@ -189,7 +189,7 @@ def __init__(
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
- num_videos_per_prompt: int = 1,
+ num_images_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
@@ -223,8 +223,8 @@ def _get_t5_prompt_embeds(
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
@@ -233,7 +233,7 @@ def encode_prompt(
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
- num_videos_per_prompt: int = 1,
+ num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
@@ -252,8 +252,8 @@ def encode_prompt(
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
- num_videos_per_prompt (`int`, *optional*, defaults to 1):
- Number of videos that should be generated per prompt.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ Number of images that should be generated per prompt.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -277,7 +277,7 @@ def encode_prompt(
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
- num_videos_per_prompt=num_videos_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
@@ -301,7 +301,7 @@ def encode_prompt(
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
- num_videos_per_prompt=num_videos_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
@@ -551,14 +551,13 @@ def __call__(
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `224`):
- Maximum sequence length in encoded prompt. Must be consistent with
- `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+ Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
Examples:
Returns:
- [`~pipelines.cogvideo.pipeline_cogvideox.CogView3PipelineOutput`] or `tuple`:
- [`~pipelines.cogvideo.pipeline_cogvideox.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
+ [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] or `tuple`:
+ [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
@@ -604,7 +603,7 @@ def __call__(
prompt,
negative_prompt,
do_classifier_free_guidance,
- num_videos_per_prompt=num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
@@ -635,14 +634,14 @@ def __call__(
# 7. Prepare additional timestep conditions
# TODO: Make this like SDXL
- original_size = torch.tensor(original_size, dtype=prompt_embeds.dtype)
- target_size = torch.tensor(target_size, dtype=prompt_embeds.dtype)
- crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=prompt_embeds.dtype)
+ original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
+ target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
+ crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)
if negative_original_size is not None and negative_target_size is not None:
- negative_original_size = torch.tensor(negative_original_size, dtype=prompt_embeds.dtype)
- negative_target_size = torch.tensor(negative_target_size, dtype=prompt_embeds.dtype)
- negative_crops_coords_top_left = torch.tensor(negative_crops_coords_top_left, dtype=prompt_embeds.dtype)
+ negative_original_size = torch.tensor([negative_original_size], dtype=prompt_embeds.dtype)
+ negative_target_size = torch.tensor([negative_target_size], dtype=prompt_embeds.dtype)
+ negative_crops_coords_top_left = torch.tensor([negative_crops_coords_top_left], dtype=prompt_embeds.dtype)
else:
negative_original_size = original_size
negative_target_size = target_size
@@ -652,6 +651,10 @@ def __call__(
original_size = torch.cat([negative_original_size, original_size])
target_size = torch.cat([negative_target_size, target_size])
crops_coords_top_left = torch.cat([negative_crops_coords_top_left, crops_coords_top_left])
+
+ original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
+ target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
+ crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)
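After this change the additional conditions are plain per-sample tensors. A quick shape check under assumed call parameters (one prompt, two images per prompt, classifier-free guidance on):

```python
import torch

batch_size, num_images_per_prompt = 1, 2  # hypothetical call parameters
do_classifier_free_guidance = True

original_size = torch.tensor([(1024, 1024)], dtype=torch.float32)  # (1, 2) after the list wrapping
negative_original_size = original_size.clone()

if do_classifier_free_guidance:
    original_size = torch.cat([negative_original_size, original_size])  # (2, 2): negative row first

original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
print(original_size.shape)  # torch.Size([4, 2]) -- same batch size as the doubled latent_model_input
```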
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
From e77b98873b5a7262099c32043d0221856f254acd Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 16:56:13 +0200
Subject: [PATCH 09/30] add example
---
.../cogview3/pipeline_cogview3plus.py | 23 +++++++------------
1 file changed, 8 insertions(+), 15 deletions(-)
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 8608370789ce..1d96daf3726d 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -37,21 +37,14 @@
Examples:
```python
>>> import torch
- >>> from diffusers import CogVideoXPipeline
- >>> from diffusers.utils import export_to_video
-
- >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
- >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
- >>> prompt = (
- ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
- ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
- ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
- ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
- ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
- ... "atmosphere of this unique musical performance."
- ... )
- >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
- >>> export_to_video(video, "output.mp4", fps=8)
+ >>> from diffusers import CogView3PlusPipeline
+
+ >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3B", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ >>> image.save("output.png")
```
"""
From 958d3e743d74463c3915f224255d28d30710ed07 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 17:42:36 +0200
Subject: [PATCH 10/30] add tests; refactor
---
scripts/convert_cogview3_to_diffusers.py | 18 +-
src/diffusers/models/embeddings.py | 7 +-
.../transformers/transformer_cogview3plus.py | 24 +-
.../test_models_transformer_cogview3plus.py | 12 +-
tests/pipelines/cogview3/__init__.py | 0
tests/pipelines/cogview3/test_cogview3plus.py | 310 ++++++++++++++++++
6 files changed, 341 insertions(+), 30 deletions(-)
create mode 100644 tests/pipelines/cogview3/__init__.py
create mode 100644 tests/pipelines/cogview3/test_cogview3plus.py
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
index 389bfc854737..632f5c4bc7f4 100644
--- a/scripts/convert_cogview3_to_diffusers.py
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -119,17 +119,17 @@ def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
- new_state_dict[block_prefix + "attn.to_q.weight"] = q
- new_state_dict[block_prefix + "attn.to_q.bias"] = q_bias
- new_state_dict[block_prefix + "attn.to_k.weight"] = k
- new_state_dict[block_prefix + "attn.to_k.bias"] = k_bias
- new_state_dict[block_prefix + "attn.to_v.weight"] = v
- new_state_dict[block_prefix + "attn.to_v.bias"] = v_bias
-
- new_state_dict[block_prefix + "attn.to_out.0.weight"] = original_state_dict.pop(
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+ new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+ new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+ new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
+
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
old_prefix + "attention.dense.weight"
)
- new_state_dict[block_prefix + "attn.to_out.0.bias"] = original_state_dict.pop(
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
old_prefix + "attention.dense.bias"
)
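The renaming above only touches the target keys; the split of the fused SAT query/key/value tensors is unchanged. A small sketch of that split with made-up shapes (the `block_prefix` and `inner_dim` here are illustrative, not the real 3B configuration):

```python
import torch

inner_dim = 8                      # illustrative; the real model uses num_heads * head_dim
qkv_weight = torch.randn(3 * inner_dim, inner_dim)
qkv_bias = torch.randn(3 * inner_dim)

# Fused QKV is stored stacked along dim 0, so chunking recovers the three projections.
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)

block_prefix = "transformer_blocks.0."
new_state_dict = {
    block_prefix + "attn1.to_q.weight": q,
    block_prefix + "attn1.to_q.bias": q_bias,
    block_prefix + "attn1.to_k.weight": k,
    block_prefix + "attn1.to_k.bias": k_bias,
    block_prefix + "attn1.to_v.weight": v,
    block_prefix + "attn1.to_v.bias": v_bias,
}
print({name: tuple(param.shape) for name, param in new_state_dict.items()})
```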
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index dc3599a1ec3f..ea809b3c0302 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1133,13 +1133,13 @@ def forward(self, timestep, guidance, pooled_projection):
class CogView3CombinedTimestepConditionEmbeddings(nn.Module):
- def __init__(self, timestep_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim=256):
+ def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
super().__init__()
self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
- self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=timestep_dim)
- self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, timestep_dim, act_fn="silu")
+ self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+ self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(
self,
@@ -1158,7 +1158,6 @@ def forward(
[original_size_proj, crop_coords_proj, target_size_proj], dim=1
) # (B, 3 * condition_dim)
-
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
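As a sanity check on the renamed arguments, here is roughly how one size condition flows through the `Timesteps`-based `condition_proj`; the dimensions use the default `condition_dim=256` and the tensor values are placeholders:

```python
import torch
from diffusers.models.embeddings import Timesteps

condition_dim = 256
condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)

original_size = torch.tensor([[1024, 1024]])                 # (B, 2) = (height, width)
proj = condition_proj(original_size.flatten()).view(original_size.size(0), -1)
print(proj.shape)  # torch.Size([1, 512]); three such conditions concatenate to 3 * 512 = 1536
```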
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index d6a023f5dde3..d02ada9dc572 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -20,7 +20,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
-from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ...models.attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
@@ -48,7 +48,7 @@ def __init__(
self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
- self.attn = Attention(
+ self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
@@ -57,6 +57,7 @@ def __init__(
qk_norm="layer_norm",
elementwise_affine=False,
eps=1e-6,
+ processor=CogVideoXAttnProcessor2_0(),
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
@@ -87,29 +88,24 @@ def forward(
) = self.norm1(hidden_states, encoder_hidden_states, emb)
# attention
- attn_input = torch.cat((norm_encoder_hidden_states, norm_hidden_states), dim=1)
- attn_output = self.attn(hidden_states=attn_input)
- context_attn_output, attn_output = attn_output[:, :text_seq_length], attn_output[:, text_seq_length:]
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states)
- attn_output = gate_msa.unsqueeze(1) * attn_output
- hidden_states = hidden_states + attn_output
+ hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
- # context norm
- context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
- encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# feed-forward
- norm_hidden_states = torch.cat((norm_encoder_hidden_states, norm_hidden_states), dim=1)
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
@@ -169,7 +165,7 @@ def __init__(
)
self.time_condition_embed = CogView3CombinedTimestepConditionEmbeddings(
- timestep_dim=time_embed_dim,
+ embedding_dim=time_embed_dim,
condition_dim=condition_dim,
pooled_projection_dim=pooled_projection_dim,
timesteps_dim=self.inner_dim,
@@ -281,7 +277,7 @@ def fuse_qkv_projections(self):
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
- self.set_attn_processor(FusedJointAttnProcessor2_0())
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
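The block refactor above keeps the joint text+image feed-forward but now applies the attention and MLP residuals through separate adaLN gates for the two streams. A toy reproduction of just that bookkeeping, with a plain `nn.Sequential` standing in for the real `FeedForward` and all shapes illustrative:

```python
import torch
import torch.nn as nn

batch, text_len, image_len, dim = 2, 4, 16, 8
hidden_states = torch.randn(batch, image_len, dim)            # image tokens
encoder_hidden_states = torch.randn(batch, text_len, dim)     # text tokens
gate_mlp = torch.randn(batch, dim)                            # adaLN gates (illustrative values)
c_gate_mlp = torch.randn(batch, dim)
ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

# The feed-forward runs once on the concatenated [text, image] sequence...
joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)
ff_output = ff(joint)

# ...and the result is split back and added through the per-stream gates.
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_len:]
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_len]
print(hidden_states.shape, encoder_hidden_states.shape)
```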
diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py
index a82417b2f669..0212b95ba130 100644
--- a/tests/models/transformers/test_models_transformer_cogview3plus.py
+++ b/tests/models/transformers/test_models_transformer_cogview3plus.py
@@ -45,13 +45,17 @@ def dummy_input(self):
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
- pooled_projections = torch.randn((batch_size, embedding_dim)).to(torch_device)
+ original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
+ target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
+ crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
- "pooled_projections": pooled_projections,
+ "original_size": original_size,
+ "target_size": target_size,
+ "crop_coords": crop_coords,
"timestep": timestep,
}
@@ -73,8 +77,10 @@ def prepare_init_args_and_inputs_for_common(self):
"out_channels": 4,
"text_embed_dim": 8,
"time_embed_dim": 8,
- "pooled_projection_dim": 8,
+ "condition_dim": 2,
+ "pooled_projection_dim": 12,
"pos_embed_max_size": 8,
+ "sample_size": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
diff --git a/tests/pipelines/cogview3/__init__.py b/tests/pipelines/cogview3/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py
new file mode 100644
index 000000000000..a0036b571bb0
--- /dev/null
+++ b/tests/pipelines/cogview3/test_cogview3plus.py
@@ -0,0 +1,310 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKL, CogView3PlusTransformer2DModel, CogView3PlusPipeline, CogVideoXDDIMScheduler
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ numpy_cosine_similarity_distance,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fusion_matches_attn_procs_length,
+ check_qkv_fusion_processors_exist,
+ to_np,
+)
+
+
+enable_full_determinism()
+
+
+class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = CogView3PlusPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = CogView3PlusTransformer2DModel(
+ patch_size=2,
+ in_channels=4,
+ num_layers=1,
+ attention_head_dim=4,
+ num_attention_heads=2,
+ out_channels=4,
+            text_embed_dim=32,  # must match the hidden size of hf-internal-testing/tiny-random-t5
+ time_embed_dim=8,
+ condition_dim=2,
+ pooled_projection_dim=12,
+ pos_embed_max_size=8,
+ sample_size=8,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=128,
+ )
+
+ torch.manual_seed(0)
+ scheduler = CogVideoXDDIMScheduler()
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 16, 16))
+ expected_image = torch.randn(3, 16, 16)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+        # Test passing in everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ images = pipe(**inputs)[0] # [B, C, H, W]
+ original_image_slice = images[0, -1, -3:, -3:]
+
+ pipe.fuse_qkv_projections()
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ pipe.transformer, pipe.transformer.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+ inputs = self.get_dummy_inputs(device)
+ images = pipe(**inputs)[0]
+ image_slice_fused = images[0, -1, -3:, -3:]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ images = pipe(**inputs)[0]
+ image_slice_disabled = images[0, -1, -3:, -3:]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
+
+
+@slow
+@require_torch_gpu
+class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_cogview3plus(self):
+ generator = torch.Generator("cpu").manual_seed(0)
+
+ pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b", torch_dtype=torch.float16)
+ pipe.enable_model_cpu_offload()
+ prompt = self.prompt
+
+ images = pipe(
+ prompt=prompt,
+ height=1024,
+ width=1024,
+ generator=generator,
+ num_inference_steps=2,
+ output_type="np",
+ )[0]
+
+ image = images[0]
+ expected_image = torch.randn(1, 1024, 1024, 3).numpy()
+
+ max_diff = numpy_cosine_similarity_distance(image, expected_image)
+        assert max_diff < 1e-3, f"Max diff is too high. got {max_diff}"
From ab8c65f9e947ff2fa89c5dcec17b80a13db0b81f Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 17:42:53 +0200
Subject: [PATCH 11/30] make style
---
.../models/transformers/transformer_cogview3plus.py | 11 +++++++++--
.../pipelines/cogview3/pipeline_cogview3plus.py | 2 +-
tests/pipelines/cogview3/test_cogview3plus.py | 2 +-
3 files changed, 11 insertions(+), 4 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index d02ada9dc572..7d587651af42 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -20,7 +20,12 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
-from ...models.attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from ...models.attention_processor import (
+ Attention,
+ AttentionProcessor,
+ CogVideoXAttnProcessor2_0,
+ FusedCogVideoXAttnProcessor2_0,
+)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
@@ -88,7 +93,9 @@ def forward(
) = self.norm1(hidden_states, encoder_hidden_states, emb)
# attention
- attn_hidden_states, attn_encoder_hidden_states = self.attn1(hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states)
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+ )
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 1d96daf3726d..8f2452a48805 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -644,7 +644,7 @@ def __call__(
original_size = torch.cat([negative_original_size, original_size])
target_size = torch.cat([negative_target_size, target_size])
crops_coords_top_left = torch.cat([negative_crops_coords_top_left, crops_coords_top_left])
-
+
original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)
diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py
index a0036b571bb0..433bdbd90728 100644
--- a/tests/pipelines/cogview3/test_cogview3plus.py
+++ b/tests/pipelines/cogview3/test_cogview3plus.py
@@ -20,7 +20,7 @@
import torch
from transformers import AutoTokenizer, T5EncoderModel
-from diffusers import AutoencoderKL, CogView3PlusTransformer2DModel, CogView3PlusPipeline, CogVideoXDDIMScheduler
+from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
From 059812ac0b19284dda7049fa3d21865cacda9e2c Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 17:47:25 +0200
Subject: [PATCH 12/30] make fix-copies
---
src/diffusers/models/transformers/transformer_cogview3plus.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 7d587651af42..38ddd6efa011 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -260,7 +260,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
From 3138ad19b5281bac0093861aefc22be828a6a548 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 17:52:22 +0200
Subject: [PATCH 13/30] add co-author
YiYi Xu
From 86909dc3c5df79604099c3e76be1934233c6c968 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 17:52:38 +0200
Subject: [PATCH 14/30] remove files
---
show_model.py | 91 -------------------------------------------
show_model_cogview.py | 25 ------------
2 files changed, 116 deletions(-)
delete mode 100644 show_model.py
delete mode 100644 show_model_cogview.py
diff --git a/show_model.py b/show_model.py
deleted file mode 100644
index 0127243117bd..000000000000
--- a/show_model.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import torch
-from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
-from diffusers import AutoencoderKL
-from huggingface_hub import hf_hub_download
-from sgm.models.autoencoder import AutoencodingEngine
-
-# (1) create vae_sat
-# AutoencodingEngine initialization arguments:
-encoder_config={'target': 'sgm.modules.diffusionmodules.model.Encoder', 'params': {'attn_type': 'vanilla', 'double_z': True, 'z_channels': 16, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 4, 8, 8], 'num_res_blocks': 3, 'attn_resolutions': [], 'mid_attn': False, 'dropout': 0.0}}
-decoder_config={'target': 'sgm.modules.diffusionmodules.model.Decoder', 'params': {'attn_type': 'vanilla', 'double_z': True, 'z_channels': 16, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 4, 8, 8], 'num_res_blocks': 3, 'attn_resolutions': [], 'mid_attn': False, 'dropout': 0.0}}
-loss_config={'target': 'torch.nn.Identity'}
-regularizer_config={'target': 'sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer'}
-optimizer_config=None
-lr_g_factor=1.0
-ckpt_path="/raid/.cache/huggingface/models--ZP2HF--CogView3-SAT/snapshots/ca86ce9ba94f9a7f2dd109e7a59e4c8ad04121be/3plus_ae/imagekl_ch16.pt"
-ignore_keys= []
-kwargs = {"monitor": "val/rec_loss"}
-vae_sat = AutoencodingEngine(
- encoder_config=encoder_config,
- decoder_config=decoder_config,
- loss_config=loss_config,
- regularizer_config=regularizer_config,
- optimizer_config=optimizer_config,
- lr_g_factor=lr_g_factor,
- ckpt_path=ckpt_path,
- ignore_keys=ignore_keys,
- **kwargs)
-
-
-
-# (2) create vae (diffusers)
-ckpt_path_vae_cogview3 = hf_hub_download(repo_id="ZP2HF/CogView3-SAT", subfolder="3plus_ae", filename="imagekl_ch16.pt")
-cogview3_ckpt = torch.load(ckpt_path_vae_cogview3, map_location='cpu')["state_dict"]
-
-in_channels = 3 # Inferred from encoder.conv_in.weight shape
-out_channels = 3 # Inferred from decoder.conv_out.weight shape
-down_block_types = ("DownEncoderBlock2D",) * 4 # Inferred from the presence of 4 encoder.down blocks
-up_block_types = ("UpDecoderBlock2D",) * 4 # Inferred from the presence of 4 decoder.up blocks
-block_out_channels = (128, 512, 1024, 1024) # Inferred from the channel sizes in encoder.down blocks
-layers_per_block = 3 # Inferred from the number of blocks in each encoder.down and decoder.up
-act_fn = "silu" # This is the default, cannot be inferred from state_dict
-latent_channels = 16 # Inferred from decoder.conv_in.weight shape
-norm_num_groups = 32 # This is the default, cannot be inferred from state_dict
-sample_size = 1024 # This is the default, cannot be inferred from state_dict
-scaling_factor = 0.18215 # This is the default, cannot be inferred from state_dict
-force_upcast = True # This is the default, cannot be inferred from state_dict
-use_quant_conv = False # Inferred from the presence of encoder.conv_out
-use_post_quant_conv = False # Inferred from the presence of decoder.conv_in
-mid_block_add_attention = False # Inferred from the absence of attention layers in mid blocks
-
-vae = AutoencoderKL(
- in_channels=in_channels,
- out_channels=out_channels,
- down_block_types=down_block_types,
- up_block_types=up_block_types,
- block_out_channels=block_out_channels,
- layers_per_block=layers_per_block,
- act_fn=act_fn,
- latent_channels=latent_channels,
- norm_num_groups=norm_num_groups,
- sample_size=sample_size,
- scaling_factor=scaling_factor,
- force_upcast=force_upcast,
- use_quant_conv=use_quant_conv,
- use_post_quant_conv=use_post_quant_conv,
- mid_block_add_attention=mid_block_add_attention,
-)
-
-vae.eval()
-vae_sat.eval()
-
-converted_vae_state_dict = convert_ldm_vae_checkpoint(cogview3_ckpt, vae.config)
-vae.load_state_dict(converted_vae_state_dict, strict=False)
-
-# (3) run forward pass for both models
-
-# [2, 16, 128, 128] -> [2, 3, 1024, 1024
-z = torch.load("z.pt").float().to("cpu")
-
-with torch.no_grad():
- print(" ")
- print(f" running forward pass for diffusers vae")
- out = vae.decode(z).sample
- print(f" ")
- print(f" running forward pass for sgm vae")
- out_sat = vae_sat.decode(z)
-
-print(f" output shape: {out.shape}")
-print(f" expected output shape: {out_sat.shape}")
-assert out.shape == out_sat.shape
-assert (out - out_sat).abs().max() < 1e-4, f"max diff: {(out - out_sat).abs().max()}"
\ No newline at end of file
diff --git a/show_model_cogview.py b/show_model_cogview.py
deleted file mode 100644
index 5314930cb127..000000000000
--- a/show_model_cogview.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import torch
-from diffusers import CogView3PlusTransformer2DModel
-
-model = CogView3PlusTransformer2DModel.from_pretrained("/share/home/zyx/Models/CogView3Plus_hf/transformer",torch_dtype=torch.bfloat16)
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
-
-batch_size = 1
-hidden_states = torch.ones((batch_size, 16, 256, 256), device=device, dtype=torch.bfloat16)
-timestep = torch.full((batch_size,), 999.0, device=device, dtype=torch.bfloat16)
-y = torch.ones((batch_size, 1536), device=device, dtype=torch.bfloat16)
-
-# Simulate a call to the forward method
-outputs = model(
-    hidden_states=hidden_states,  # hidden_states input
-    timestep=timestep,  # timestep input
-    y=y,  # label input
-    block_controlnet_hidden_states=None,  # can be omitted if not needed
-    return_dict=True,  # keep the default value
-    target_size=[(2048, 2048)],
-)
-
-# Print the model output
-print("Output shape:", outputs.sample.shape)
\ No newline at end of file
From 7da234bc29f1ae1bf2be5f3c98dd44e0fba55c18 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 18:04:58 +0200
Subject: [PATCH 15/30] add docs
---
docs/source/en/_toctree.yml | 2 +
.../en/api/models/cogview3_transformer2d.md | 30 +++++++
docs/source/en/api/pipelines/cogview3.md | 79 +++++++++++++++++++
3 files changed, 111 insertions(+)
create mode 100644 docs/source/en/api/models/cogview3_transformer2d.md
create mode 100644 docs/source/en/api/pipelines/cogview3.md
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index b331e4b13760..2c332270d72a 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -242,6 +242,8 @@
title: AuraFlowTransformer2DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
+ - local: api/models/transformer_cogview3plus
+ title: CogView3PlusTransformer2DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/flux_transformer
diff --git a/docs/source/en/api/models/cogview3_transformer2d.md b/docs/source/en/api/models/cogview3_transformer2d.md
new file mode 100644
index 000000000000..16f71a58cfb4
--- /dev/null
+++ b/docs/source/en/api/models/cogview3_transformer2d.md
@@ -0,0 +1,30 @@
+
+
+# CogView3PlusTransformer2DModel
+
+A Diffusion Transformer model for 2D data from [CogView3Plus](https://github.com/THUDM/CogView3) was introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) by Tsinghua University & ZhipuAI.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import CogView3PlusTransformer2DModel
+
+transformer = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## CogView3PlusTransformer2DModel
+
+[[autodoc]] CogView3PlusTransformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/cogview3.md b/docs/source/en/api/pipelines/cogview3.md
new file mode 100644
index 000000000000..8a170193f3e9
--- /dev/null
+++ b/docs/source/en/api/pipelines/cogview3.md
@@ -0,0 +1,79 @@
+
+
+# CogVideoX
+
+[CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) from Tsinghua University & ZhipuAI, by Wendi Zheng, Jiayan Teng, Zhuoyi Yang, Weihan Wang, Jidong Chen, Xiaotao Gu, Yuxiao Dong, Ming Ding, Jie Tang.
+
+The abstract from the paper is:
+
+*Recent advancements in text-to-image generative systems have been largely driven by diffusion models. However, single-stage text-to-image diffusion models still face challenges, in terms of computational efficiency and the refinement of image details. To tackle the issue, we propose CogView3, an innovative cascaded framework that enhances the performance of text-to-image diffusion. CogView3 is the first model implementing relay diffusion in the realm of text-to-image generation, executing the task by first creating low-resolution images and subsequently applying relay-based super-resolution. This methodology not only results in competitive text-to-image outputs but also greatly reduces both training and inference costs. Our experimental results demonstrate that CogView3 outperforms SDXL, the current state-of-the-art open-source text-to-image diffusion model, by 77.0% in human evaluations, all while requiring only about 1/2 of the inference time. The distilled variant of CogView3 achieves comparable performance while only utilizing 1/10 of the inference time by SDXL.*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
+
+## Inference
+
+Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
+
+First, load the pipeline:
+
+```python
+import torch
+from diffusers import CogView3PlusPipeline
+from diffusers.utils import export_to_video,load_image
+
+pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b").to("cuda") # or "THUDM/CogVideoX-2b"
+```
+
+Then change the memory layout of the `transformer` and `vae` components to `torch.channels_last`:
+
+```python
+pipe.transformer.to(memory_format=torch.channels_last)
+pipe.vae.to(memory_format=torch.channels_last)
+```
+
+Compile the components and run inference:
+
+```python
+pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
+pipe.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)
+
+# CogVideoX works well with long and well-described prompts
+prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
+video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+```
+
+The [benchmark](TODO) results on an 80GB A100 machine are:
+
+```
+Without torch.compile(): Average inference time: TODO seconds.
+With torch.compile(): Average inference time: TODO seconds.
+```
+
+## CogView3PlusPipeline
+
+[[autodoc]] CogView3PlusPipeline
+ - all
+ - __call__
+
+## CogView3PipelineOutput
+
+[[autodoc]] pipelines.cogview3.pipeline_output.CogView3PipelineOutput
From 0114e3b9e4559dbad8958f225d73c789e6f12aec Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 18:05:31 +0200
Subject: [PATCH 16/30] add co-author
Co-Authored-By: YiYi Xu
From 4e8de65586e62652b2025d195c6a5154163e9ada Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 8 Oct 2024 18:13:35 +0200
Subject: [PATCH 17/30] fight docs
---
docs/source/en/_toctree.yml | 4 +++-
...ogview3_transformer2d.md => cogview3plus_transformer2d.md} | 0
2 files changed, 3 insertions(+), 1 deletion(-)
rename docs/source/en/api/models/{cogview3_transformer2d.md => cogview3plus_transformer2d.md} (100%)
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 2c332270d72a..22613cb343ff 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -242,7 +242,7 @@
title: AuraFlowTransformer2DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
- - local: api/models/transformer_cogview3plus
+ - local: api/models/cogview3plus_transformer2d
title: CogView3PlusTransformer2DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
@@ -322,6 +322,8 @@
title: BLIP-Diffusion
- local: api/pipelines/cogvideox
title: CogVideoX
+ - local: api/pipelines/cogview3
+ title: CogView3
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
diff --git a/docs/source/en/api/models/cogview3_transformer2d.md b/docs/source/en/api/models/cogview3plus_transformer2d.md
similarity index 100%
rename from docs/source/en/api/models/cogview3_transformer2d.md
rename to docs/source/en/api/models/cogview3plus_transformer2d.md
From f35850c0d98952f072a16a177c3635b9a1d6a8a5 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 9 Oct 2024 15:35:33 +0200
Subject: [PATCH 18/30] address reviews
---
docs/source/en/api/pipelines/cogview3.md | 41 +-----
src/diffusers/models/embeddings.py | 118 +++++++++---------
.../transformers/transformer_cogview3plus.py | 50 ++++++--
src/diffusers/pipelines/__init__.py | 2 +-
.../cogview3/pipeline_cogview3plus.py | 49 ++------
5 files changed, 110 insertions(+), 150 deletions(-)
diff --git a/docs/source/en/api/pipelines/cogview3.md b/docs/source/en/api/pipelines/cogview3.md
index 8a170193f3e9..85a9cf91736f 100644
--- a/docs/source/en/api/pipelines/cogview3.md
+++ b/docs/source/en/api/pipelines/cogview3.md
@@ -13,7 +13,7 @@
# limitations under the License.
-->
-# CogVideoX
+# CogView3Plus
[CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) from Tsinghua University & ZhipuAI, by Wendi Zheng, Jiayan Teng, Zhuoyi Yang, Weihan Wang, Jidong Chen, Xiaotao Gu, Yuxiao Dong, Ming Ding, Jie Tang.
@@ -29,45 +29,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
-## Inference
-
-Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
-
-First, load the pipeline:
-
-```python
-import torch
-from diffusers import CogView3PlusPipeline
-from diffusers.utils import export_to_video,load_image
-
-pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b").to("cuda") # or "THUDM/CogVideoX-2b"
-```
-
-Then change the memory layout of the `transformer` and `vae` components to `torch.channels_last`:
-
-```python
-pipe.transformer.to(memory_format=torch.channels_last)
-pipe.vae.to(memory_format=torch.channels_last)
-```
-
-Compile the components and run inference:
-
-```python
-pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
-pipe.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)
-
-# CogVideoX works well with long and well-described prompts
-prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
-video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
-```
-
-The [benchmark](TODO) results on an 80GB A100 machine are:
-
-```
-Without torch.compile(): Average inference time: TODO seconds.
-With torch.compile(): Average inference time: TODO seconds.
-```
-
## CogView3PlusPipeline
[[autodoc]] CogView3PlusPipeline
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index ea809b3c0302..2d2bad7dfab1 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -442,6 +442,60 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
return embeds
+class CogView3PlusPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 16,
+ hidden_size: int = 2560,
+ patch_size: int = 2,
+ text_hidden_size: int = 4096,
+ pos_embed_max_size: int = 128,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_size = hidden_size
+ self.patch_size = patch_size
+ self.text_hidden_size = text_hidden_size
+ self.pos_embed_max_size = pos_embed_max_size
+ # Linear projection for image patches
+ self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+ # Linear projection for text embeddings
+ self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+ pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+ pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, channel, height, width = hidden_states.shape
+
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
+ raise ValueError("Height and width must be divisible by patch size")
+
+ height = height // self.patch_size
+ width = width // self.patch_size
+ hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
+ hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
+ hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+
+ # Project the patches
+ hidden_states = self.proj(hidden_states)
+ encoder_hidden_states = self.text_proj(encoder_hidden_states)
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ # Calculate text_length
+ text_length = encoder_hidden_states.shape[1]
+
+ image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
+ text_pos_embed = torch.zeros(
+ (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+ )
+ pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
+
+ return (hidden_states + pos_embed).to(hidden_states.dtype)
+
+
def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
@@ -714,58 +768,6 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
return freqs_cos, freqs_sin
-class CogView3PlusPatchEmbed(nn.Module):
- def __init__(
- self,
- in_channels: int = 16,
- hidden_size: int = 2560,
- patch_size: int = 2,
- text_hidden_size: int = 4096,
- pos_embed_max_size: int = 128,
- ):
- super().__init__()
- self.in_channels = in_channels
- self.hidden_size = hidden_size
- self.patch_size = patch_size
- self.text_hidden_size = text_hidden_size
- self.pos_embed_max_size = pos_embed_max_size
- # Linear projection for image patches
- self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
-
- # Linear projection for text embeddings
- self.text_proj = nn.Linear(text_hidden_size, hidden_size)
-
- pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
- pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
- self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
-
- def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None) -> torch.Tensor:
- batch_size, channel, height, width = hidden_states.shape
- if height % self.patch_size != 0 or width % self.patch_size != 0:
- raise ValueError("Height and width must be divisible by patch size")
- height = height // self.patch_size
- width = width // self.patch_size
- hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
- hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
- hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
-
- # Project the patches
- hidden_states = self.proj(hidden_states)
- encoder_hidden_states = self.text_proj(encoder_hidden_states)
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- # Calculate text_length
- text_length = encoder_hidden_states.shape[1]
-
- image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
- text_pos_embed = torch.zeros(
- (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
- )
- pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
-
- return (hidden_states + pos_embed).to(hidden_states.dtype)
-
-
class TimestepEmbedding(nn.Module):
def __init__(
self,
@@ -1090,11 +1092,11 @@ def forward(self, timestep, class_labels, hidden_dtype=None):
class CombinedTimestepTextProjEmbeddings(nn.Module):
- def __init__(self, embedding_dim, pooled_projection_dim, timesteps_dim=256):
+ def __init__(self, embedding_dim, pooled_projection_dim):
super().__init__()
- self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
- self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(self, timestep, pooled_projection):
@@ -1132,7 +1134,7 @@ def forward(self, timestep, guidance, pooled_projection):
return conditioning
-class CogView3CombinedTimestepConditionEmbeddings(nn.Module):
+class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
super().__init__()
@@ -1154,9 +1156,11 @@ def forward(
original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
+
+ # (B, 3 * condition_dim)
condition_proj = torch.cat(
[original_size_proj, crop_coords_proj, target_size_proj], dim=1
- ) # (B, 3 * condition_dim)
+ )
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
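The relocated `CogView3PlusPatchEmbed.forward` above is mostly shape bookkeeping; the patchify step it performs can be reproduced in isolation like this (latent size, channel count, and patch size are illustrative):

```python
import torch

batch_size, channel, height, width, patch_size = 1, 16, 32, 32, 2
hidden_states = torch.randn(batch_size, channel, height, width)

h, w = height // patch_size, width // patch_size
# (B, C, H, W) -> (B, H/p, W/p, C, p, p) -> (B, H/p * W/p, C * p * p)
patches = hidden_states.view(batch_size, channel, h, patch_size, w, patch_size)
patches = patches.permute(0, 2, 4, 1, 3, 5).contiguous()
patches = patches.view(batch_size, h * w, channel * patch_size * patch_size)
print(patches.shape)  # torch.Size([1, 256, 64]); each token is then projected to hidden_size
```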
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 38ddd6efa011..240ca9aec61f 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -29,7 +29,7 @@
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
-from ..embeddings import CogView3CombinedTimestepConditionEmbeddings, CogView3PlusPatchEmbed
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
@@ -133,12 +133,27 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
The number of channels in each head.
num_attention_heads (`int`, defaults to `64`):
The number of heads to use for multi-head attention.
- out_channels (`int`, *optional*, defaults to `16`):
+ out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
+ condition_dim (`int`, defaults to `256`):
+ The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size, crop_coords).
+ pooled_projection_dim (`int`, defaults to `1536`):
+            The overall pooled dimension obtained by concatenating the SDXL-style resolution conditions. As 3 additional conditions are
+ used (original_size, target_size, crop_coords), and each is a sinusoidal condition of dimension `2 * condition_dim`,
+ we get the pooled projection dimension as `2 * condition_dim * 3 => 1536`. The timestep embeddings will be projected
+ to this dimension as well.
+ TODO(yiyi): Do we need this parameter based on the above explanation?
+ pos_embed_max_size (`int`, defaults to `128`):
+ The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added to input
+ patched latents, where `H` and `W` are the latent height and width respectively. A value of 128 means that the maximum
+ supported height and width for image generation is `128 * vae_scale_factor * patch_size => 128 * 8 * 2 => 2048`.
+ sample_size (`int`, defaults to `128`):
+ The base resolution of input latents. If height/width is not provided during generation, this value is used to determine
+ the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
"""
_supports_gradient_checkpointing = True
@@ -163,7 +178,7 @@ def __init__(
self.out_channels = out_channels
self.inner_dim = num_attention_heads * attention_head_dim
- self.pos_embed = CogView3PlusPatchEmbed(
+ self.patch_embed = CogView3PlusPatchEmbed(
in_channels=in_channels,
hidden_size=self.inner_dim,
patch_size=patch_size,
@@ -171,7 +186,7 @@ def __init__(
pos_embed_max_size=pos_embed_max_size,
)
- self.time_condition_embed = CogView3CombinedTimestepConditionEmbeddings(
+ self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
pooled_projection_dim=pooled_projection_dim,
@@ -318,20 +333,31 @@ def forward(
The [`CogView3PlusTransformer2DModel`] forward method.
Args:
- hidden_states (`torch.Tensor`): Input `hidden_states`.
- timestep (`torch.LongTensor`): Indicates denoising step.
-            y (`torch.LongTensor`, *optional*): Label input, used to obtain label embeddings.
- block_controlnet_hidden_states: (`list` of `torch.Tensor`): A list of tensors for residuals.
- joint_attention_kwargs (`dict`, *optional*): Additional kwargs for the attention processor.
- return_dict (`bool`, *optional*, defaults to `True`): Whether to return a `Transformer2DModelOutput`.
+ hidden_states (`torch.Tensor`):
+ Input `hidden_states` of shape `(batch size, channel, height, width)`.
+ encoder_hidden_states (`torch.Tensor`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts)
+ of shape `(batch_size, sequence_len, text_embed_dim)`
+ timestep (`torch.LongTensor`):
+ Used to indicate denoising step.
+ original_size (`torch.Tensor`):
+ CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`torch.Tensor`):
+ CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crop_coords (`torch.Tensor`):
+ CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
Returns:
- Output tensor or `Transformer2DModelOutput`.
+ `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
+ The denoised latents using provided inputs as conditioning.
"""
height, width = hidden_states.shape[-2:]
text_seq_length = encoder_hidden_states.shape[1]
- hidden_states = self.pos_embed(
+ hidden_states = self.patch_embed(
hidden_states, encoder_hidden_states
) # takes care of adding positional embeddings too.
emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
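To make the documented forward signature concrete, here is a tiny end-to-end call that mirrors the configuration of the fast model test earlier in this series; the config values and tensor shapes are illustrative and far smaller than the released 3B checkpoint:

```python
import torch
from diffusers import CogView3PlusTransformer2DModel

transformer = CogView3PlusTransformer2DModel(
    patch_size=2,
    in_channels=4,
    num_layers=1,
    attention_head_dim=4,
    num_attention_heads=2,
    out_channels=4,
    text_embed_dim=8,
    time_embed_dim=8,
    condition_dim=2,
    pooled_projection_dim=12,
    pos_embed_max_size=8,
    sample_size=8,
)

batch_size = 2
out = transformer(
    hidden_states=torch.randn(batch_size, 4, 16, 16),          # (B, C, H, W) latents
    encoder_hidden_states=torch.randn(batch_size, 8, 8),        # (B, seq_len, text_embed_dim)
    timestep=torch.randint(0, 1000, (batch_size,)),
    original_size=torch.tensor([[128, 128]] * batch_size),      # SDXL-style micro-conditioning
    target_size=torch.tensor([[128, 128]] * batch_size),
    crop_coords=torch.tensor([[0, 0]] * batch_size),
).sample
print(out.shape)  # (batch_size, out_channels, 16, 16)
```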
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 7ff219efdf5f..1ceca647cd9d 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -146,7 +146,7 @@
"CogVideoXVideoToVideoPipeline",
]
_import_structure["cogview3"] = [
- "CogView3PlusPipeline",
+ "CogView3PlusPipeline"
]
_import_structure["controlnet"].extend(
[
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 8f2452a48805..6975b9181e30 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -229,7 +229,7 @@ def encode_prompt(
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
- max_sequence_length: int = 226,
+ max_sequence_length: int = 224,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
@@ -254,6 +254,8 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+ max_sequence_length (`int`, defaults to `224`):
+ Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
@@ -430,7 +432,7 @@ def __call__(
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
- guidance_scale: float = 6,
+ guidance_scale: float = 5.0,
use_dynamic_cfg: bool = False,
num_images_per_prompt: int = 1,
eta: float = 0.0,
@@ -440,10 +442,6 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
- target_size: Optional[Tuple[int, int]] = None,
- negative_original_size: Optional[Tuple[int, int]] = None,
- negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
- negative_target_size: Optional[Tuple[int, int]] = None,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
@@ -473,7 +471,7 @@ def __call__(
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 7.0):
+ guidance_scale (`float`, *optional*, defaults to `5.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
@@ -505,25 +503,6 @@ def __call__(
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
- target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
- For most cases, `target_size` should be set to the desired height and width of the generated image. If
- not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
- section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
- negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
- To negatively condition the generation process based on a specific image resolution. Part of SDXL's
- micro-conditioning as explained in section 2.2 of
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
- negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
- To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
- micro-conditioning as explained in section 2.2 of
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
- negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
- To negatively condition the generation process based on a target image resolution. It should be as same
- as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -561,7 +540,7 @@ def __call__(
width = width or self.transformer.config.sample_size * self.vae_scale_factor
original_size = original_size or (height, width)
- target_size = target_size or (height, width)
+ target_size = (height, width)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -626,24 +605,14 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare additional timestep conditions
- # TODO: Make this like SDXL
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)
- if negative_original_size is not None and negative_target_size is not None:
- negative_original_size = torch.tensor([negative_original_size], dtype=prompt_embeds.dtype)
- negative_target_size = torch.tensor([negative_target_size], dtype=prompt_embeds.dtype)
- negative_crops_coords_top_left = torch.tensor([negative_crops_coords_top_left], dtype=prompt_embeds.dtype)
- else:
- negative_original_size = original_size
- negative_target_size = target_size
- negative_crops_coords_top_left = crops_coords_top_left
-
if do_classifier_free_guidance:
- original_size = torch.cat([negative_original_size, original_size])
- target_size = torch.cat([negative_target_size, target_size])
- crops_coords_top_left = torch.cat([negative_crops_coords_top_left, crops_coords_top_left])
+ original_size = torch.cat([original_size, original_size])
+ target_size = torch.cat([target_size, target_size])
+ crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])
original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
From b439f4c82f246713d5e312c23078705d060516db Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 9 Oct 2024 15:35:52 +0200
Subject: [PATCH 19/30] make style
---
scripts/convert_cogview3_to_diffusers.py | 33 +++++++----------
src/diffusers/models/embeddings.py | 10 ++---
.../transformers/transformer_cogview3plus.py | 37 +++++++++++--------
src/diffusers/pipelines/__init__.py | 4 +-
4 files changed, 39 insertions(+), 45 deletions(-)
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
index 632f5c4bc7f4..3e281bba8b47 100644
--- a/scripts/convert_cogview3_to_diffusers.py
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -6,26 +6,19 @@
Example usage:
python scripts/convert_cogview3_to_diffusers.py \
- --original_state_dict_repo_id "THUDM/cogview3-sat" \
- --filename "cogview3.pt" \
- --transformer \
- --output_path "./cogview3_diffusers" \
- --dtype "bf16"
-
-Alternatively, if you have a local checkpoint:
- python scripts/convert_cogview3_to_diffusers.py \
- --checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
- --transformer \
+ --transformer_checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
+ --vae_checkpoint_path 'your path/3plus_ae/imagekl_ch16.pt' \
--output_path "/raid/yiyi/cogview3_diffusers" \
--dtype "bf16"
Arguments:
- --original_state_dict_repo_id: The Hugging Face repo ID containing the original checkpoint.
- --filename: The filename of the checkpoint in the repo (default: "flux.safetensors").
- --checkpoint_path: Path to a local checkpoint file (alternative to repo_id and filename).
- --transformer: Flag to convert the transformer model.
+ --transformer_checkpoint_path: Path to Transformer state dict.
+ --vae_checkpoint_path: Path to VAE state dict.
--output_path: The path to save the converted model.
- --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32").
+ --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
+ --text_encoder_cache_dir: Cache directory where the text encoder is located. Defaults to None, which means HF_HOME will be used.
+ --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
+
Default is "bf16" because CogView3 uses bfloat16 for training.
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
@@ -73,11 +66,11 @@ def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
new_state_dict = {}
- # Convert pos_embed
- new_state_dict["pos_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
- new_state_dict["pos_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
- new_state_dict["pos_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
- new_state_dict["pos_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
+ # Convert patch_embed
+ new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
+ new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
+ new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
+ new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
# Convert time_condition_embed
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
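The key conversion above follows a simple pop-and-rename pattern. A sketch with a hypothetical `remap_keys` helper (not part of the patch) that captures it:

```python
# Hypothetical helper illustrating the pop-and-rename pattern used by the
# conversion function: keys are popped from the original state dict so that
# anything left over afterwards is easy to spot as unconverted.
def remap_keys(original_state_dict: dict, mapping: dict) -> dict:
    new_state_dict = {}
    for old_key, new_key in mapping.items():
        new_state_dict[new_key] = original_state_dict.pop(old_key)
    return new_state_dict

patch_embed_mapping = {
    "mixins.patch_embed.proj.weight": "patch_embed.proj.weight",
    "mixins.patch_embed.proj.bias": "patch_embed.proj.bias",
    "mixins.patch_embed.text_proj.weight": "patch_embed.text_proj.weight",
    "mixins.patch_embed.text_proj.bias": "patch_embed.text_proj.bias",
}
```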
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 2d2bad7dfab1..44f01c46ebe8 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -469,10 +469,10 @@ def __init__(
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
-
+
if height % self.patch_size != 0 or width % self.patch_size != 0:
raise ValueError("Height and width must be divisible by patch size")
-
+
height = height // self.patch_size
width = width // self.patch_size
hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
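The hunk above only shows the initial `view` of the patchify step. A standalone sketch of the full patch flattening, assuming the usual permute/flatten order used by diffusers patch embeddings (the exact order in this file is not shown here):

```python
# Turn latents of shape (B, C, H, W) into a sequence of flattened patches.
# The permute/flatten order below is an assumption, not copied from embeddings.py.
import torch

batch_size, channel, height, width = 1, 16, 8, 8
patch_size = 2
hidden_states = torch.randn(batch_size, channel, height, width)

h, w = height // patch_size, width // patch_size
patches = hidden_states.view(batch_size, channel, h, patch_size, w, patch_size)
patches = patches.permute(0, 2, 4, 1, 3, 5)                       # (B, h, w, C, p, p)
patches = patches.reshape(batch_size, h * w, channel * patch_size * patch_size)
print(patches.shape)  # torch.Size([1, 16, 64])
```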
@@ -1156,11 +1156,9 @@ def forward(
original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
-
+
# (B, 3 * condition_dim)
- condition_proj = torch.cat(
- [original_size_proj, crop_coords_proj, target_size_proj], dim=1
- )
+ condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
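Each SDXL-style size condition is projected with a sinusoidal embedding before the concatenation above. A sketch using the existing `get_timestep_embedding` helper; the exact flags used by `condition_proj` here are an assumption:

```python
# Project a (height, width) pair into a sinusoidal condition embedding of
# dimension 2 * condition_dim, mirroring the flatten -> embed -> view pattern above.
import torch
from diffusers.models.embeddings import get_timestep_embedding

condition_dim = 256
original_size = torch.tensor([[1024.0, 1024.0]])                  # (B, 2)
proj = get_timestep_embedding(
    original_size.flatten(), condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0
)
proj = proj.view(original_size.size(0), -1)                       # (B, 2 * condition_dim)
print(proj.shape)  # torch.Size([1, 512])
```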
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 240ca9aec61f..9fcbaf54d57a 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -140,20 +140,22 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
condition_dim (`int`, defaults to `256`):
- The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size, crop_coords).
+ The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
+ crop_coords).
pooled_projection_dim (`int`, defaults to `1536`):
- The overall pooled dimension by concatenating SDXL-style resolution conditions. As 3 additional conditions are
- used (original_size, target_size, crop_coords), and each is a sinusoidal condition of dimension `2 * condition_dim`,
- we get the pooled projection dimension as `2 * condition_dim * 3 => 1536`. The timestep embeddings will be projected
- to this dimension as well.
- TODO(yiyi): Do we need this parameter based on the above explanation?
+ The overall pooled dimension by concatenating SDXL-style resolution conditions. As 3 additional conditions
+ are used (original_size, target_size, crop_coords), and each is a sinusoidal condition of dimension `2 *
+ condition_dim`, we get the pooled projection dimension as `2 * condition_dim * 3 => 1536`. The timestep
+ embeddings will be projected to this dimension as well. TODO(yiyi): Do we need this parameter based on the
+ above explanation?
pos_embed_max_size (`int`, defaults to `128`):
- The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added to input
- patched latents, where `H` and `W` are the latent height and width respectively. A value of 128 means that the maximum
- supported height and width for image generation is `128 * vae_scale_factor * patch_size => 128 * 8 * 2 => 2048`.
+ The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
+ to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
+ means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
+ patch_size => 128 * 8 * 2 => 2048`.
sample_size (`int`, defaults to `128`):
- The base resolution of input latents. If height/width is not provided during generation, this value is used to determine
- the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
+ The base resolution of input latents. If height/width is not provided during generation, this value is used
+ to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
"""
_supports_gradient_checkpointing = True
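A quick sanity check (not part of the patch) of the three derived quantities quoted in the docstring above, using the documented defaults:

```python
# Reproduce the arithmetic spelled out in the docstring.
condition_dim = 256
vae_scale_factor = 8
patch_size = 2
pos_embed_max_size = 128
sample_size = 128

pooled_projection_dim = 2 * condition_dim * 3                         # 1536
max_resolution = pos_embed_max_size * vae_scale_factor * patch_size   # 2048
default_resolution = sample_size * vae_scale_factor                   # 1024
print(pooled_projection_dim, max_resolution, default_resolution)      # 1536 2048 1024
```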
@@ -336,16 +338,19 @@ def forward(
hidden_states (`torch.Tensor`):
Input `hidden_states` of shape `(batch size, channel, height, width)`.
encoder_hidden_states (`torch.Tensor`):
- Conditional embeddings (embeddings computed from the input conditions such as prompts)
- of shape `(batch_size, sequence_len, text_embed_dim)`
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
+ `(batch_size, sequence_len, text_embed_dim)`
timestep (`torch.LongTensor`):
Used to indicate denoising step.
original_size (`torch.Tensor`):
- CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`torch.Tensor`):
- CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crop_coords (`torch.Tensor`):
- CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 1ceca647cd9d..f05e349ea187 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -145,9 +145,7 @@
"CogVideoXImageToVideoPipeline",
"CogVideoXVideoToVideoPipeline",
]
- _import_structure["cogview3"] = [
- "CogView3PlusPipeline"
- ]
+ _import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
From 9d9b0b286f73db6ba3cc9474551a8f7d84a27578 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 9 Oct 2024 16:29:37 +0200
Subject: [PATCH 20/30] make model work
---
scripts/convert_cogview3_to_diffusers.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
index 3e281bba8b47..481357adcea0 100644
--- a/scripts/convert_cogview3_to_diffusers.py
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -192,7 +192,7 @@ def main(args):
"latent_channels": 16,
"norm_num_groups": 32,
"sample_size": 1024,
- "scaling_factor": 0.18215,
+ "scaling_factor": 1.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,
From 2158f00e3f2ee03c556c4dc17540f4d19e321631 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 9 Oct 2024 16:30:35 +0200
Subject: [PATCH 21/30] remove qkv fusion
---
.../transformers/transformer_cogview3plus.py | 41 -------------------
.../cogview3/pipeline_cogview3plus.py | 13 ------
2 files changed, 54 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 9fcbaf54d57a..79390f68ce84 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -24,7 +24,6 @@
Attention,
AttentionProcessor,
CogVideoXAttnProcessor2_0,
- FusedCogVideoXAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
@@ -277,46 +276,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 6975b9181e30..e10da74bb554 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -397,19 +397,6 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)
- def fuse_qkv_projections(self) -> None:
- r"""Enables fused QKV projections."""
- self.fusing_transformer = True
- self.transformer.fuse_qkv_projections()
-
- def unfuse_qkv_projections(self) -> None:
- r"""Disable QKV projection fusion if enabled."""
- if not self.fusing_transformer:
- logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
- else:
- self.transformer.unfuse_qkv_projections()
- self.fusing_transformer = False
-
@property
def guidance_scale(self):
return self._guidance_scale
From 1b060207cfd2d0870a4dafbf7a6e16cec9bb7e96 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 9 Oct 2024 16:33:11 +0200
Subject: [PATCH 22/30] remove qkv fusion tests
---
tests/pipelines/cogview3/test_cogview3plus.py | 40 -------------------
1 file changed, 40 deletions(-)
diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py
index 433bdbd90728..ee2448800b30 100644
--- a/tests/pipelines/cogview3/test_cogview3plus.py
+++ b/tests/pipelines/cogview3/test_cogview3plus.py
@@ -32,8 +32,6 @@
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
to_np,
)
@@ -233,44 +231,6 @@ def test_attention_slicing_forward_pass(
"Attention slicing should not affect the inference results",
)
- def test_fused_qkv_projections(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- images = pipe(**inputs)[0] # [B, C, H, W]
- original_image_slice = images[0, -1, -3:, -3:]
-
- pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
-
- inputs = self.get_dummy_inputs(device)
- images = pipe(**inputs)[0]
- image_slice_fused = images[0, -1, -3:, -3:]
-
- pipe.transformer.unfuse_qkv_projections()
- inputs = self.get_dummy_inputs(device)
- images = pipe(**inputs)[0]
- image_slice_disabled = images[0, -1, -3:, -3:]
-
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
-
@slow
@require_torch_gpu
From 80e7cca64f972fdc99fd5b477068cfe0acdc67ab Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 11 Oct 2024 01:10:32 +0200
Subject: [PATCH 23/30] address review comments
---
.../cogview3/pipeline_cogview3plus.py | 24 ++++---------------
1 file changed, 4 insertions(+), 20 deletions(-)
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index e10da74bb554..88d1dd1b2c69 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -49,25 +49,6 @@
"""
-# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
-def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
- tw = tgt_width
- th = tgt_height
- h, w = src
- r = h / w
- if r > (th / tw):
- resize_height = th
- resize_width = int(round(th / h * w))
- else:
- resize_width = tw
- resize_height = int(round(tw / w * h))
-
- crop_top = int(round((th - resize_height) / 2.0))
- crop_left = int(round((tw - resize_width) / 2.0))
-
- return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
-
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -179,6 +160,7 @@ def __init__(
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
@@ -278,8 +260,10 @@ def encode_prompt(
dtype=dtype,
)
+ if do_classifier_free_guidance and negative_prompt is None:
+ negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
- negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
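The added branch above uses all-zero embeddings as the unconditional input when no negative prompt is given. A minimal illustration (not from the patch):

```python
# The unconditional branch becomes a zero tensor matching the prompt embeddings
# in shape, dtype, and device.
import torch

prompt_embeds = torch.randn(1, 224, 4096, dtype=torch.bfloat16)
negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)
assert negative_prompt_embeds.shape == prompt_embeds.shape
assert negative_prompt_embeds.dtype == prompt_embeds.dtype
```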
From 0e4577df71065205e666bcea39eb1410ecd570dd Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 11 Oct 2024 12:58:56 +0200
Subject: [PATCH 24/30] fix make fix-copies error
---
src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 88d1dd1b2c69..8080fa95c395 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -160,7 +160,7 @@ def __init__(
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds with num_videos_per_prompt->num_images_per_prompt
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
From 9c3a81dedf8841391107599b54d56462e81d97db Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sat, 12 Oct 2024 00:43:56 +0800
Subject: [PATCH 25/30] remove None and TODO
---
scripts/convert_cogview3_to_diffusers.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
index 481357adcea0..259380fd6388 100644
--- a/scripts/convert_cogview3_to_diffusers.py
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -156,8 +156,7 @@ def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
def main(args):
- if args.dtype is None:
- dtype = None
+
if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
@@ -212,7 +211,6 @@ def main(args):
for param in text_encoder.parameters():
param.data = param.data.contiguous()
- # TODO: figure out the correct scheduler if it is same as CogVideoXDDIMScheduler
scheduler = CogVideoXDDIMScheduler.from_config(
{
"snr_shift_scale": 4.0,
From 6603901dc33e980c121c2b66dc62a831c8920907 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sat, 12 Oct 2024 01:30:47 +0800
Subject: [PATCH 26/30] for FP16 (draft)
---
src/diffusers/models/transformers/transformer_cogview3plus.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 79390f68ce84..27033c998537 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -113,6 +113,10 @@ def forward(
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return hidden_states, encoder_hidden_states
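The clamp bound above is the largest finite float16 value. A small demonstration (not part of the patch) of why the clip keeps half-precision activations finite:

```python
# 65504 is torch.finfo(torch.float16).max; values beyond it overflow to inf in fp16,
# while clipping pins them to the largest representable magnitude.
import torch

fp16_max = torch.finfo(torch.float16).max            # 65504.0
x = torch.tensor([1e5, -1e5, 3.0]).to(torch.float16)
print(x)                                              # tensor([inf, -inf, 3.], dtype=torch.float16)
print(x.clip(-fp16_max, fp16_max))                    # tensor([65504., -65504., 3.], dtype=torch.float16)
```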
From db2a958c9036001079d1db0df05b63c95fedd73b Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 11 Oct 2024 22:51:48 +0200
Subject: [PATCH 27/30] make style
---
scripts/convert_cogview3_to_diffusers.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
index 259380fd6388..48cda2084240 100644
--- a/scripts/convert_cogview3_to_diffusers.py
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -156,7 +156,6 @@ def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
def main(args):
-
if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
From 270d407eac081a2692d517a8245dd7f35235d85b Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 14 Oct 2024 10:04:14 +0200
Subject: [PATCH 28/30] remove dynamic cfg
---
.../cogview3/pipeline_cogview3plus.py | 23 ++++++++++---------
1 file changed, 12 insertions(+), 11 deletions(-)
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 8080fa95c395..7ae86421c45e 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -14,7 +14,6 @@
# limitations under the License.
import inspect
-import math
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
@@ -385,6 +384,13 @@ def check_inputs(
def guidance_scale(self):
return self._guidance_scale
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
@property
def num_timesteps(self):
return self._num_timesteps
@@ -404,7 +410,6 @@ def __call__(
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 5.0,
- use_dynamic_cfg: bool = False,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -545,14 +550,14 @@ def __call__(
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
- do_classifier_free_guidance,
+ self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
@@ -580,7 +585,7 @@ def __call__(
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
original_size = torch.cat([original_size, original_size])
target_size = torch.cat([target_size, target_size])
crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])
@@ -599,7 +604,7 @@ def __call__(
if self.interrupt:
continue
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -618,11 +623,7 @@ def __call__(
noise_pred = noise_pred.float()
# perform guidance
- if use_dynamic_cfg:
- self._guidance_scale = 1 + guidance_scale * (
- (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
- )
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
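For reference, the deleted `use_dynamic_cfg` branch rescaled the guidance weight per step according to the cosine schedule below, with `T = num_inference_steps` and `t` the current scheduler timestep (notation only, reconstructed from the removed lines):

```latex
w(t) = 1 + w_{\text{base}} \cdot \frac{1 - \cos\!\left(\pi \left(\frac{T - t}{T}\right)^{5}\right)}{2}
```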
From 21dd8900f35a592be070a0aa7d69321e7e888730 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 14 Oct 2024 10:37:28 +0200
Subject: [PATCH 29/30] remove pooled_projection_dim as a parameter
---
.../transformers/transformer_cogview3plus.py | 27 ++++++++++++-------
1 file changed, 17 insertions(+), 10 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 27033c998537..962cbbff7c1b 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -37,8 +37,18 @@
class CogView3PlusTransformerBlock(nn.Module):
- """
- Updated CogView3 Transformer Block to align with AdalnAttentionMixin style, simplified with qk_ln always True.
+ r"""
+ Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.
+
+ Args:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
"""
def __init__(
@@ -145,12 +155,6 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
condition_dim (`int`, defaults to `256`):
The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
crop_coords).
- pooled_projection_dim (`int`, defaults to `1536`):
- The overall pooled dimension by concatenating SDXL-style resolution conditions. As 3 additional conditions
- are used (original_size, target_size, crop_coords), and each is a sinusoidal condition of dimension `2 *
- condition_dim`, we get the pooled projection dimension as `2 * condition_dim * 3 => 1536`. The timestep
- embeddings will be projected to this dimension as well. TODO(yiyi): Do we need this parameter based on the
- above explanation?
pos_embed_max_size (`int`, defaults to `128`):
The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
@@ -175,7 +179,6 @@ def __init__(
text_embed_dim: int = 4096,
time_embed_dim: int = 512,
condition_dim: int = 256,
- pooled_projection_dim: int = 1536,
pos_embed_max_size: int = 128,
sample_size: int = 128,
):
@@ -183,6 +186,10 @@ def __init__(
self.out_channels = out_channels
self.inner_dim = num_attention_heads * attention_head_dim
+ # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
+ # Each of these are sincos embeddings of shape 2 * condition_dim
+ self.pooled_projection_dim = 3 * 2 * condition_dim
+
self.patch_embed = CogView3PlusPatchEmbed(
in_channels=in_channels,
hidden_size=self.inner_dim,
@@ -194,7 +201,7 @@ def __init__(
self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
- pooled_projection_dim=pooled_projection_dim,
+ pooled_projection_dim=self.pooled_projection_dim,
timesteps_dim=self.inner_dim,
)
From 221b486cd04187e8a9471947aa4d92398cdb4e0c Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 14 Oct 2024 14:10:08 +0200
Subject: [PATCH 30/30] fix tests
---
.../models/transformers/test_models_transformer_cogview3plus.py | 1 -
tests/pipelines/cogview3/test_cogview3plus.py | 1 -
2 files changed, 2 deletions(-)
diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py
index 0212b95ba130..46612dbd9190 100644
--- a/tests/models/transformers/test_models_transformer_cogview3plus.py
+++ b/tests/models/transformers/test_models_transformer_cogview3plus.py
@@ -78,7 +78,6 @@ def prepare_init_args_and_inputs_for_common(self):
"text_embed_dim": 8,
"time_embed_dim": 8,
"condition_dim": 2,
- "pooled_projection_dim": 12,
"pos_embed_max_size": 8,
"sample_size": 8,
}
diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py
index ee2448800b30..8d56552ba5ee 100644
--- a/tests/pipelines/cogview3/test_cogview3plus.py
+++ b/tests/pipelines/cogview3/test_cogview3plus.py
@@ -69,7 +69,6 @@ def get_dummy_components(self):
text_embed_dim=32, # Must match with tiny-random-t5
time_embed_dim=8,
condition_dim=2,
- pooled_projection_dim=12,
pos_embed_max_size=8,
sample_size=8,
)