diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index b331e4b13760..22613cb343ff 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -242,6 +242,8 @@
       title: AuraFlowTransformer2DModel
     - local: api/models/cogvideox_transformer3d
       title: CogVideoXTransformer3DModel
+    - local: api/models/cogview3plus_transformer2d
+      title: CogView3PlusTransformer2DModel
     - local: api/models/dit_transformer2d
       title: DiTTransformer2DModel
     - local: api/models/flux_transformer
@@ -320,6 +322,8 @@
       title: BLIP-Diffusion
     - local: api/pipelines/cogvideox
       title: CogVideoX
+    - local: api/pipelines/cogview3
+      title: CogView3
     - local: api/pipelines/consistency_models
       title: Consistency Models
     - local: api/pipelines/controlnet
diff --git a/docs/source/en/api/models/cogview3plus_transformer2d.md b/docs/source/en/api/models/cogview3plus_transformer2d.md
new file mode 100644
index 000000000000..16f71a58cfb4
--- /dev/null
+++ b/docs/source/en/api/models/cogview3plus_transformer2d.md
@@ -0,0 +1,30 @@
+
+
+# CogView3PlusTransformer2DModel
+
+A Diffusion Transformer model for 2D data from [CogView3Plus](https://github.com/THUDM/CogView3) was introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) by Tsinghua University & ZhipuAI.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import CogView3PlusTransformer2DModel
+
+transformer = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## CogView3PlusTransformer2DModel
+
+[[autodoc]] CogView3PlusTransformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/cogview3.md b/docs/source/en/api/pipelines/cogview3.md
new file mode 100644
index 000000000000..85a9cf91736f
--- /dev/null
+++ b/docs/source/en/api/pipelines/cogview3.md
@@ -0,0 +1,40 @@
+
+
+# CogView3Plus
+
+[CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) from Tsinghua University & ZhipuAI, by Wendi Zheng, Jiayan Teng, Zhuoyi Yang, Weihan Wang, Jidong Chen, Xiaotao Gu, Yuxiao Dong, Ming Ding, Jie Tang.
+
+The abstract from the paper is:
+
+*Recent advancements in text-to-image generative systems have been largely driven by diffusion models. However, single-stage text-to-image diffusion models still face challenges, in terms of computational efficiency and the refinement of image details. To tackle the issue, we propose CogView3, an innovative cascaded framework that enhances the performance of text-to-image diffusion. CogView3 is the first model implementing relay diffusion in the realm of text-to-image generation, executing the task by first creating low-resolution images and subsequently applying relay-based super-resolution. This methodology not only results in competitive text-to-image outputs but also greatly reduces both training and inference costs. Our experimental results demonstrate that CogView3 outperforms SDXL, the current state-of-the-art open-source text-to-image diffusion model, by 77.0% in human evaluations, all while requiring only about 1/2 of the inference time. The distilled variant of CogView3 achieves comparable performance while only utilizing 1/10 of the inference time by SDXL.*
+
+<Tip>
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+</Tip>
+
+This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
+
+## CogView3PlusPipeline
+
+[[autodoc]] CogView3PlusPipeline
+  - all
+  - __call__
+
+## CogView3PipelineOutput
+
+[[autodoc]] pipelines.cogview3.pipeline_output.CogView3PipelineOutput
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
new file mode 100644
index 000000000000..48cda2084240
--- /dev/null
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -0,0 +1,242 @@
+"""
+Convert a CogView3 checkpoint to the Diffusers format.
+
+This script converts a CogView3 checkpoint to the Diffusers format, which can then be used
+with the Diffusers library.
+
+Example usage:
+    python scripts/convert_cogview3_to_diffusers.py \
+        --transformer_checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
+        --vae_checkpoint_path 'your path/3plus_ae/imagekl_ch16.pt' \
+        --output_path "/raid/yiyi/cogview3_diffusers" \
+        --dtype "bf16"
+
+Arguments:
+    --transformer_checkpoint_path: Path to Transformer state dict.
+    --vae_checkpoint_path: Path to VAE state dict.
+    --output_path: The path to save the converted model.
+    --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
+    --text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used.
+    --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
+
+    Default is "bf16" because CogView3 uses bfloat16 for training.
+
+Note: You must provide at least one of --transformer_checkpoint_path or --vae_checkpoint_path.
+""" + +import argparse +from contextlib import nullcontext + +import torch +from accelerate import init_empty_weights +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel +from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +TOKENIZER_MAX_LENGTH = 224 + +parser = argparse.ArgumentParser() +parser.add_argument("--transformer_checkpoint_path", default=None, type=str) +parser.add_argument("--vae_checkpoint_path", default=None, type=str) +parser.add_argument("--output_path", required=True, type=str) +parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving") +parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory") +parser.add_argument("--dtype", type=str, default="bf16") + +args = parser.parse_args() + + +# this is specific to `AdaLayerNormContinuous`: +# diffusers implementation split the linear projection into the scale, shift while CogView3 split it tino shift, scale +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path): + original_state_dict = torch.load(ckpt_path, map_location="cpu") + original_state_dict = original_state_dict["module"] + original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()} + + new_state_dict = {} + + # Convert patch_embed + new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight") + new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias") + new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight") + new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias") + + # Convert time_condition_embed + new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop( + "time_embed.0.weight" + ) + new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop( + "time_embed.0.bias" + ) + new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop( + "time_embed.2.weight" + ) + new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop( + "time_embed.2.bias" + ) + new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop( + "label_emb.0.0.weight" + ) + new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop( + "label_emb.0.0.bias" + ) + new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop( + "label_emb.0.2.weight" + ) + new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop( + "label_emb.0.2.bias" + ) + + # Convert transformer blocks + for i in range(30): + block_prefix = f"transformer_blocks.{i}." + old_prefix = f"transformer.layers.{i}." + adaln_prefix = f"mixins.adaln.adaln_modules.{i}." 
+
+        new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
+        new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
+
+        qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
+        qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+        q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
+        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
+        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
+
+        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
+            old_prefix + "attention.dense.weight"
+        )
+        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
+            old_prefix + "attention.dense.bias"
+        )
+
+        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
+            old_prefix + "mlp.dense_h_to_4h.weight"
+        )
+        new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
+            old_prefix + "mlp.dense_h_to_4h.bias"
+        )
+        new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
+            old_prefix + "mlp.dense_4h_to_h.weight"
+        )
+        new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
+
+    # Convert final norm and projection
+    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
+    )
+    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+        original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
+    )
+    new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
+    new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
+
+    return new_state_dict
+
+
+def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
+    original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+    return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
+
+
+def main(args):
+    if args.dtype == "fp16":
+        dtype = torch.float16
+    elif args.dtype == "bf16":
+        dtype = torch.bfloat16
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+    else:
+        raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+    transformer = None
+    vae = None
+
+    if args.transformer_checkpoint_path is not None:
+        converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(
+            args.transformer_checkpoint_path
+        )
+        transformer = CogView3PlusTransformer2DModel()
+        transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+        if dtype is not None:
+            # Cast only when an explicit dtype is requested; otherwise the original checkpoint dtype is preserved
+            transformer = transformer.to(dtype=dtype)
+
+    if args.vae_checkpoint_path is not None:
+        vae_config = {
+            "in_channels": 3,
+            "out_channels": 3,
+            "down_block_types": ("DownEncoderBlock2D",) * 4,
+            "up_block_types": ("UpDecoderBlock2D",) * 4,
+            "block_out_channels": (128, 512, 1024, 1024),
+            "layers_per_block": 3,
+            "act_fn": "silu",
+            "latent_channels": 16,
+            "norm_num_groups": 32,
+            "sample_size": 1024,
+            "scaling_factor": 1.0,
+            "force_upcast": True,
+            "use_quant_conv": False,
+            "use_post_quant_conv": False,
+            "mid_block_add_attention": False,
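+            # The values above mirror the original CogView3 VAE checkpoint (`imagekl_ch16.pt`):
+            # a 16-channel latent space, no (post-)quant convolutions, and a 1024px base sample size.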
} + converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config) + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_state_dict, strict=True) + if dtype is not None: + vae = vae.to(dtype=dtype) + + text_encoder_id = "google/t5-v1_1-xxl" + tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) + text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + + # Apparently, the conversion does not work anymore without this :shrug: + for param in text_encoder.parameters(): + param.data = param.data.contiguous() + + scheduler = CogVideoXDDIMScheduler.from_config( + { + "snr_shift_scale": 4.0, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "clip_sample": False, + "num_train_timesteps": 1000, + "prediction_type": "v_prediction", + "rescale_betas_zero_snr": True, + "set_alpha_to_one": True, + "timestep_spacing": "trailing", + } + ) + + pipe = CogView3PlusPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + + # This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can + # save some memory used for model loading. + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub) + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fadb234c6e10..019d744730ab 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -84,6 +84,7 @@ "AutoencoderOobleck", "AutoencoderTiny", "CogVideoXTransformer3DModel", + "CogView3PlusTransformer2DModel", "ConsistencyDecoderVAE", "ControlNetModel", "ControlNetXSAdapter", @@ -258,6 +259,7 @@ "CogVideoXImageToVideoPipeline", "CogVideoXPipeline", "CogVideoXVideoToVideoPipeline", + "CogView3PlusPipeline", "CycleDiffusionPipeline", "FluxControlNetImg2ImgPipeline", "FluxControlNetInpaintPipeline", @@ -559,6 +561,7 @@ AutoencoderOobleck, AutoencoderTiny, CogVideoXTransformer3DModel, + CogView3PlusTransformer2DModel, ConsistencyDecoderVAE, ControlNetModel, ControlNetXSAdapter, @@ -711,6 +714,7 @@ CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline, + CogView3PlusPipeline, CycleDiffusionPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f0dd7248c117..4dda8c36ba1c 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -54,6 +54,7 @@ _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] + _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] @@ -98,6 +99,7 @@ from .transformers import ( AuraFlowTransformer2DModel, CogVideoXTransformer3DModel, + CogView3PlusTransformer2DModel, DiTTransformer2DModel, DualTransformer2DModel, FluxTransformer2DModel, diff --git a/src/diffusers/models/attention_processor.py 
b/src/diffusers/models/attention_processor.py index 9f9bc5a46e10..d333590982e3 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -122,6 +122,7 @@ def __init__( out_dim: int = None, context_pre_only=None, pre_only=False, + elementwise_affine: bool = True, ): super().__init__() @@ -179,8 +180,8 @@ def __init__( self.norm_q = None self.norm_k = None elif qk_norm == "layer_norm": - self.norm_q = nn.LayerNorm(dim_head, eps=eps) - self.norm_k = nn.LayerNorm(dim_head, eps=eps) + self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) elif qk_norm == "fp32_layer_norm": self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c250df29afbe..44f01c46ebe8 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -442,6 +442,60 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): return embeds +class CogView3PlusPatchEmbed(nn.Module): + def __init__( + self, + in_channels: int = 16, + hidden_size: int = 2560, + patch_size: int = 2, + text_hidden_size: int = 4096, + pos_embed_max_size: int = 128, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.text_hidden_size = text_hidden_size + self.pos_embed_max_size = pos_embed_max_size + # Linear projection for image patches + self.proj = nn.Linear(in_channels * patch_size**2, hidden_size) + + # Linear projection for text embeddings + self.text_proj = nn.Linear(text_hidden_size, hidden_size) + + pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size) + pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, channel, height, width = hidden_states.shape + + if height % self.patch_size != 0 or width % self.patch_size != 0: + raise ValueError("Height and width must be divisible by patch size") + + height = height // self.patch_size + width = width // self.patch_size + hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous() + hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size) + + # Project the patches + hidden_states = self.proj(hidden_states) + encoder_hidden_states = self.text_proj(encoder_hidden_states) + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # Calculate text_length + text_length = encoder_hidden_states.shape[1] + + image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1) + text_pos_embed = torch.zeros( + (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device + ) + pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...] 
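+        # Add the positional embedding to the combined [text, image] token sequence in a single step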
+
+        return (hidden_states + pos_embed).to(hidden_states.dtype)
+
+
 def get_3d_rotary_pos_embed(
     embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
@@ -1080,6 +1134,39 @@ def forward(self, timestep, guidance, pooled_projection):
         return conditioning
 
 
+class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
+    def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+        self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        original_size: torch.Tensor,
+        target_size: torch.Tensor,
+        crop_coords: torch.Tensor,
+        hidden_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        timesteps_proj = self.time_proj(timestep)
+
+        original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
+        crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
+        target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
+
+        # (B, 3 * condition_dim)
+        condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)
+
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+        condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+
+        conditioning = timesteps_emb + condition_emb
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 5740fed9f30c..21e9d3cd6fc5 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -355,6 +355,51 @@ def forward(
         return x
 
 
+class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        dim (`int`): The dimension of the image and text hidden states to be normalized.
+ """ + + def __init__(self, embedding_dim: int, dim: int): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) + self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + c_shift_msa, + c_scale_msa, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = emb.chunk(12, dim=1) + normed_x = self.norm_x(x) + normed_context = self.norm_c(context) + x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None] + context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp + + class CogVideoXLayerNormZero(nn.Module): def __init__( self, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index d55dfe57d6f3..58787c079ea8 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -14,6 +14,7 @@ from .stable_audio_transformer import StableAudioDiTModel from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel + from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py new file mode 100644 index 000000000000..962cbbff7c1b --- /dev/null +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -0,0 +1,386 @@ +# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+from typing import Any, Dict, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.attention import FeedForward
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    CogVideoXAttnProcessor2_0,
+)
+from ...models.modeling_utils import ModelMixin
+from ...models.normalization import AdaLayerNormContinuous
+from ...utils import is_torch_version, logging
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class CogView3PlusTransformerBlock(nn.Module):
+    r"""
+    Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.
+
+    Args:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        time_embed_dim (`int`):
+            The number of channels in timestep embedding.
+    """
+
+    def __init__(
+        self,
+        dim: int = 2560,
+        num_attention_heads: int = 64,
+        attention_head_dim: int = 40,
+        time_embed_dim: int = 512,
+    ):
+        super().__init__()
+
+        self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            out_dim=dim,
+            bias=True,
+            qk_norm="layer_norm",
+            elementwise_affine=False,
+            eps=1e-6,
+            processor=CogVideoXAttnProcessor2_0(),
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        emb: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        # norm & modulate
+        (
+            norm_hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            norm_encoder_hidden_states,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        ) = self.norm1(hidden_states, encoder_hidden_states, emb)
+
+        # attention
+        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+        )
+
+        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
+        encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
+
+        # norm & modulate
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+        # feed-forward
+        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+        ff_output = self.ff(norm_hidden_states)
+
+        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
+
+        if hidden_states.dtype == torch.float16:
+            hidden_states = hidden_states.clip(-65504, 65504)
+        if encoder_hidden_states.dtype ==
torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + return hidden_states, encoder_hidden_states + + +class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): + r""" + The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay + Diffusion](https://huggingface.co/papers/2403.05121). + + Args: + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + attention_head_dim (`int`, defaults to `40`): + The number of channels in each head. + num_attention_heads (`int`, defaults to `64`): + The number of heads to use for multi-head attention. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + condition_dim (`int`, defaults to `256`): + The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size, + crop_coords). + pos_embed_max_size (`int`, defaults to `128`): + The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added + to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128 + means that the maximum supported height and width for image generation is `128 * vae_scale_factor * + patch_size => 128 * 8 * 2 => 2048`. + sample_size (`int`, defaults to `128`): + The base resolution of input latents. If height/width is not provided during generation, this value is used + to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024` + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 30, + attention_head_dim: int = 40, + num_attention_heads: int = 64, + out_channels: int = 16, + text_embed_dim: int = 4096, + time_embed_dim: int = 512, + condition_dim: int = 256, + pos_embed_max_size: int = 128, + sample_size: int = 128, + ): + super().__init__() + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + + # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords + # Each of these are sincos embeddings of shape 2 * condition_dim + self.pooled_projection_dim = 3 * 2 * condition_dim + + self.patch_embed = CogView3PlusPatchEmbed( + in_channels=in_channels, + hidden_size=self.inner_dim, + patch_size=patch_size, + text_hidden_size=text_embed_dim, + pos_embed_max_size=pos_embed_max_size, + ) + + self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=self.pooled_projection_dim, + timesteps_dim=self.inner_dim, + ) + + self.transformer_blocks = nn.ModuleList( + [ + CogView3PlusTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + embedding_dim=self.inner_dim, + conditioning_embedding_dim=time_embed_dim, + elementwise_affine=False, + eps=1e-6, + ) + self.proj_out = 
nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + original_size: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`CogView3PlusTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor`): + Input `hidden_states` of shape `(batch size, channel, height, width)`. + encoder_hidden_states (`torch.Tensor`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape + `(batch_size, sequence_len, text_embed_dim)` + timestep (`torch.LongTensor`): + Used to indicate denoising step. 
+ original_size (`torch.Tensor`): + CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`torch.Tensor`): + CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crop_coords (`torch.Tensor`): + CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]: + The denoised latents using provided inputs as conditioning. + """ + height, width = hidden_states.shape[-2:] + text_seq_length = encoder_hidden_states.shape[1] + + hidden_states = self.patch_embed( + hidden_states, encoder_hidden_states + ) # takes care of adding positional embeddings too. + emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype) + + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=emb, + ) + + hidden_states = self.norm_out(hidden_states, emb) + hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels) + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size) + ) + hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 45a868eb5810..d7ff34310beb 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -145,6 +145,7 @@ "CogVideoXImageToVideoPipeline", "CogVideoXVideoToVideoPipeline", ] + _import_structure["cogview3"] = ["CogView3PlusPipeline"] _import_structure["controlnet"].extend( [ "BlipDiffusionControlNetPipeline", @@ -470,6 +471,7 @@ from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline + from .cogview3 import CogView3PlusPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, 
StableDiffusionControlNetImg2ImgPipeline,
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index ebeaf9f8aeec..0214d7dd6f3c 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -20,6 +20,7 @@
 from ..configuration_utils import ConfigMixin
 from ..utils import is_sentencepiece_available
 from .aura_flow import AuraFlowPipeline
+from .cogview3 import CogView3PlusPipeline
 from .controlnet import (
     StableDiffusionControlNetImg2ImgPipeline,
     StableDiffusionControlNetInpaintPipeline,
@@ -119,6 +120,7 @@
         ("flux", FluxPipeline),
         ("flux-controlnet", FluxControlNetPipeline),
         ("lumina", LuminaText2ImgPipeline),
+        ("cogview3", CogView3PlusPipeline),
     ]
 )
diff --git a/src/diffusers/pipelines/cogview3/__init__.py b/src/diffusers/pipelines/cogview3/__init__.py
new file mode 100644
index 000000000000..50895251ba0b
--- /dev/null
+++ b/src/diffusers/pipelines/cogview3/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["CogView3PipelineOutput"]}
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["pipeline_cogview3plus"] = ["CogView3PlusPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
+    else:
+        from .pipeline_cogview3plus import CogView3PlusPipeline
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
+    for name, value in _additional_imports.items():
+        setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
new file mode 100644
index 000000000000..7ae86421c45e
--- /dev/null
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -0,0 +1,674 @@
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, CogView3PlusTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import CogView3PipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogView3PlusPipeline + + >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogView3PlusPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using CogView3Plus. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogView3Plus uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogView3PlusTransformer2DModel`]): + A text conditioned `CogView3PlusTransformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: CogView3PlusTransformer2DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds with num_videos_per_prompt->num_images_per_prompt + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" 
{max_sequence_length} tokens: {removed_text}"
+            )
+
+        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        do_classifier_free_guidance: bool = True,
+        num_images_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 224,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+                Whether to use classifier free guidance or not.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                Number of images that should be generated per prompt.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            max_sequence_length (`int`, defaults to `224`):
+                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
+            device: (`torch.device`, *optional*):
+                torch device
+            dtype: (`torch.dtype`, *optional*):
+                torch dtype
+        """
+        device = device or self._execution_device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        if do_classifier_free_guidance and negative_prompt is None:
+            negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)
+
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}.
Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 224, + ) -> Union[CogView3PipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. If not provided, it is set to 1024. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. If not provided it is set to 1024. 
+ num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2 of the [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text + `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to `0.0`): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only used with + schedulers that accept an `eta` argument in their `step` method (e.g. the DDIM-style schedulers) and + ignored for all others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] instead + of a plain tuple.
+ callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `224`): + Maximum sequence length in the encoded prompt. Can be set to other values but may lead to poorer results. + + Examples: + + Returns: + [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] or `tuple`: + [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents.
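+ # NOTE: `prepare_latents` (defined above) samples noise of shape (batch_size * num_images_per_prompt, + # latent_channels, height // vae_scale_factor, width // vae_scale_factor) and scales it by the + # scheduler's `init_noise_sigma`, so any user-supplied `latents` must already match this shape.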
+ latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + original_size = torch.cat([original_size, original_size]) + target_size = torch.cat([target_size, target_size]) + crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left]) + + original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, 
generator=generator)[ + 0 + ] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CogView3PipelineOutput(images=image) diff --git a/src/diffusers/pipelines/cogview3/pipeline_output.py b/src/diffusers/pipelines/cogview3/pipeline_output.py new file mode 100644 index 000000000000..3891dd51e691 --- /dev/null +++ b/src/diffusers/pipelines/cogview3/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class CogView3PipelineOutput(BaseOutput): + """ + Output class for CogView3 pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`): + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 1ab946ce7257..eaab67c93b18 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -122,6 +122,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) + +class CogView3PlusTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ConsistencyDecoderVAE(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index fee69d01ebff..d8109eee6d35 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -317,6 +317,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + +class CogView3PlusPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CycleDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py new file mode 100644 index 000000000000..46612dbd9190 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_cogview3plus.py @@ -0,0 +1,85 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import CogView3PlusTransformer2DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = CogView3PlusTransformer2DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + height = 8 + width = 8 + embedding_dim = 8 + sequence_length = 8 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device) + target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device) + crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "original_size": original_size, + "target_size": target_size, + "crop_coords": crop_coords, + "timestep": timestep, + } + + @property + def input_shape(self): + return (1, 4, 8, 8) + + @property + def output_shape(self): + return (1, 4, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": 2, + "in_channels": 4, + "num_layers": 1, + "attention_head_dim": 4, + "num_attention_heads": 2, + "out_channels": 4, + "text_embed_dim": 8, + "time_embed_dim": 8, + "condition_dim": 2, + "pos_embed_max_size": 8, + "sample_size": 8, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict diff --git a/tests/pipelines/cogview3/__init__.py b/tests/pipelines/cogview3/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py new file mode 100644 index 000000000000..8d56552ba5ee --- /dev/null +++ b/tests/pipelines/cogview3/test_cogview3plus.py @@ -0,0 +1,269 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
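+ # The fast tests below exercise CogView3PlusPipeline with tiny, randomly initialized components; + # the integration tests at the bottom of this file download the real THUDM/CogView3Plus-3b weights + # and therefore require a CUDA device (they are gated behind `@slow` and `@require_torch_gpu`).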
+ + import gc + import inspect + import unittest + + import numpy as np + import torch + from transformers import AutoTokenizer, T5EncoderModel + + from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel + from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, + ) + + from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS + from ..test_pipelines_common import ( + PipelineTesterMixin, + to_np, + ) + + + enable_full_determinism() + + + class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = CogView3PlusPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = CogView3PlusTransformer2DModel( + patch_size=2, + in_channels=4, + num_layers=1, + attention_head_dim=4, + num_attention_heads=2, + out_channels=4, + text_embed_dim=32, # Must match the hidden size of tiny-random-t5 + time_embed_dim=8, + condition_dim=2, + pos_embed_max_size=8, + sample_size=8, + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + + torch.manual_seed(0) + scheduler = CogVideoXDDIMScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs)[0] + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 16, 16)) + # random placeholder baseline; together with the loose tolerance below, this effectively only checks shape + expected_image = torch.randn(3, 16, 16) + max_diff = np.abs(generated_image - expected_image).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = 
self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f"{self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + +@slow +@require_torch_gpu +class CogView3PlusPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger."
+ + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_cogview3plus(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b", torch_dtype=torch.float16) + pipe.enable_model_cpu_offload() + prompt = self.prompt + + images = pipe( + prompt=prompt, + height=1024, + width=1024, + generator=generator, + num_inference_steps=2, + output_type="np", + )[0] + + image = images[0] + # random placeholder reference; replace with outputs from a trusted run before relying on this assertion + expected_image = torch.randn(1, 1024, 1024, 3).numpy() + + max_diff = numpy_cosine_similarity_distance(image, expected_image) + assert max_diff < 1e-3, f"Max diff is too high. got {max_diff}"
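+ # Once reference outputs are recorded, a tighter, slice-based check is possible — a sketch (the + # `expected_slice` values below are hypothetical and must be captured from a verified generation): + # expected_slice = np.array([...]) + # assert np.abs(image[-3:, -3:, -1].flatten() - expected_slice).max() < 1e-3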