From f7f511769db0f7d3d844dd9a7a48015bd4f3dbf7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Oct 2024 17:05:16 +0200 Subject: [PATCH 1/7] cogvideox-fun control --- docs/source/en/api/pipelines/cogvideox.md | 10 + src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 3 +- src/diffusers/pipelines/cogvideo/__init__.py | 2 + .../pipeline_cogvideox_fun_control.py | 795 ++++++++++++++++++ .../cogvideo/test_cogvideox_fun_control.py | 328 ++++++++ 6 files changed, 1139 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py create mode 100644 tests/pipelines/cogvideo/test_cogvideox_fun_control.py diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 4cde7a111ae6..58d7a0ef1017 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -36,6 +36,10 @@ There are two models available that can be used with the text-to-video and video There is one model available that can be used with the image-to-video CogVideoX pipeline: - [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`. +There are two models that allow controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team): +- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`. +- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`. + ## Inference Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. @@ -118,6 +122,12 @@ It is also worth noting that torchao quantization is fully compatible with [torc - all - __call__ +## CogVideoXFunControlPipeline + +[[autodoc]] CogVideoXFunControlPipeline + - all + - __call__ + ## CogVideoXPipelineOutput [[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fadb234c6e10..f2eaae000456 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -258,6 +258,7 @@ "CogVideoXImageToVideoPipeline", "CogVideoXPipeline", "CogVideoXVideoToVideoPipeline", + "CogVideoXFunControlPipeline", "CycleDiffusionPipeline", "FluxControlNetImg2ImgPipeline", "FluxControlNetInpaintPipeline", @@ -711,6 +712,7 @@ CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline, + CogVideoXFunControlPipeline, CycleDiffusionPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 45a868eb5810..9ee457bb8d93 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -144,6 +144,7 @@ "CogVideoXPipeline", "CogVideoXImageToVideoPipeline", "CogVideoXVideoToVideoPipeline", + "CogVideoXFunControlPipeline", ] _import_structure["controlnet"].extend( [ @@ -469,7 +470,7 @@ ) from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline - from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline + from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline, CogVideoXFunControlPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, diff --git a/src/diffusers/pipelines/cogvideo/__init__.py b/src/diffusers/pipelines/cogvideo/__init__.py index bd60fcea9994..2cb5c7a42c0a 100644 --- a/src/diffusers/pipelines/cogvideo/__init__.py +++ b/src/diffusers/pipelines/cogvideo/__init__.py @@ -25,6 +25,7 @@ _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"] _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"] _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"] + _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -37,6 +38,7 @@ from .pipeline_cogvideox import CogVideoXPipeline from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline + from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline else: import sys diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py new file mode 100644 index 000000000000..8727056fd50c --- /dev/null +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -0,0 +1,795 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI, Alibaba-PAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer +from PIL import Image + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXFunControlPipeline, DDIMScheduler + >>> from diffusers.utils import export_to_video, load_video + + >>> pipe = CogVideoXFunControlPipeline.from_pretrained("alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16) + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.to("cuda") + + >>> control_video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4") + >>> prompt = ( + ... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and " + ... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in " + ... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, " + ... "moons, but the remainder of the scene is mostly realistic." + ... ) + + >>> video = pipe(prompt=prompt, control_video=control_video).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for controlled text-to-video generation using CogVideoX Fun. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->vae->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scaling_factor_image = ( + self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents(self, mask: Optional[torch.Tensor] = None, masked_image: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if mask is not None: + masks = [] + for i in range(mask.size(0)): + current_mask = mask[i].unsqueeze(0) + current_mask = self.vae.encode(current_mask)[0] + current_mask = current_mask.mode() + masks.append(current_mask) + mask = torch.cat(masks, dim = 0) + mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + mask_pixel_values = [] + for i in range(masked_image.size(0)): + mask_pixel_value = masked_image[i].unsqueeze(0) + mask_pixel_value = self.vae.encode(mask_pixel_value)[0] + mask_pixel_value = mask_pixel_value.mode() + mask_pixel_values.append(mask_pixel_value) + masked_image_latents = torch.cat(mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + control_video=None, + control_video_latents=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if control_video is not None and control_video_latents is not None: + raise ValueError(f"Cannot pass both `control_video` and `control_video_latents`. Please make sure to pass only one of these parameters.") + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + control_video: Optional[List[Image.Image]] = None, + height: int = 480, + width: int = 720, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + control_video_latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + control_video (`List[PIL.Image.Image]`): + The control video to condition the generation on. Must be a list of images/frames of the video. If not provided, + `control_video_latents` must be provided. + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + control_video_latents (`torch.Tensor`, *optional*): + Pre-generated control latents, sampled from a Gaussian distribution, to be used as inputs for controlled + video generation. If not provided, `control_video` must be provided. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + control_video, + control_video_latents, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if control_video is not None and isinstance(control_video[0], Image.Image): + control_video = [control_video] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels // 2 + num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2) + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if control_video_latents is None: + control_video = self.video_processor.preprocess_video(control_video, height=height, width=width) + control_video = control_video.to(device=device, dtype=prompt_embeds.dtype) + + _, control_video_latents = self.prepare_control_latents(None, control_video) + control_video_latents = control_video_latents.permute(0, 2, 1, 3, 4) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_control_input = torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + latent_model_input = torch.cat([latent_model_input, latent_control_input], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py new file mode 100644 index 000000000000..e6d417fefca2 --- /dev/null +++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py @@ -0,0 +1,328 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel +from PIL import Image + +from diffusers import AutoencoderKLCogVideoX, CogVideoXFunControlPipeline, CogVideoXTransformer3DModel, DDIMScheduler +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, + to_np, +) + + +enable_full_determinism() + + +class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = CogVideoXFunControlPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"control_video"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = CogVideoXTransformer3DModel( + # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings + # But, since we are using tiny-random-t5 here, we need the internal dim of CogVideoXTransformer3DModel + # to be 32. The internal dim is product of num_attention_heads and attention_head_dim + num_attention_heads=4, + attention_head_dim=8, + in_channels=8, + out_channels=4, + time_embed_dim=2, + text_embed_dim=32, # Must match with tiny-random-t5 + num_layers=1, + sample_width=2, # latent width: 2 -> final width: 16 + sample_height=2, # latent height: 2 -> final height: 16 + sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 + patch_size=2, + temporal_compression_ratio=4, + max_text_seq_length=16, + ) + + torch.manual_seed(0) + vae = AutoencoderKLCogVideoX( + in_channels=3, + out_channels=3, + down_block_types=( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types=( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + latent_channels=4, + layers_per_block=1, + norm_num_groups=2, + temporal_compression_ratio=4, + ) + + torch.manual_seed(0) + scheduler = DDIMScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed: int = 0, num_frames: int = 8): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + # Cannot reduce because convolution kernel becomes bigger than sample + height = 16 + width = 16 + + control_video = [Image.new("RGB", (width, height))] * num_frames + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "", + "control_video": control_video, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": height, + "width": width, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (8, 3, 16, 16)) + expected_video = torch.randn(8, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.5): + # NOTE(aryan): This requires a higher expected_max_diff than other CogVideoX pipelines + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_overlap_factor_height=1 / 12, + tile_overlap_factor_width=1 / 12, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames # [B, F, C, H, W] + original_image_slice = frames[0, -2:, -1, -3:, -3:] + + pipe.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_fused = frames[0, -2:, -1, -3:, -3:] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_disabled = frames[0, -2:, -1, -3:, -3:] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." From 850512abaee18a10038cd12656709694404e9729 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Oct 2024 17:05:40 +0200 Subject: [PATCH 2/7] make style --- src/diffusers/__init__.py | 4 +- src/diffusers/pipelines/__init__.py | 7 +++- src/diffusers/pipelines/cogvideo/__init__.py | 4 +- .../pipeline_cogvideox_fun_control.py | 41 +++++++++++-------- .../cogvideo/test_cogvideox_fun_control.py | 12 ++---- 5 files changed, 39 insertions(+), 29 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f2eaae000456..9139a2d490b9 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -255,10 +255,10 @@ "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", "CLIPImageProjection", + "CogVideoXFunControlPipeline", "CogVideoXImageToVideoPipeline", "CogVideoXPipeline", "CogVideoXVideoToVideoPipeline", - "CogVideoXFunControlPipeline", "CycleDiffusionPipeline", "FluxControlNetImg2ImgPipeline", "FluxControlNetInpaintPipeline", @@ -709,10 +709,10 @@ AudioLDMPipeline, AuraFlowPipeline, CLIPImageProjection, + CogVideoXFunControlPipeline, CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline, - CogVideoXFunControlPipeline, CycleDiffusionPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 9ee457bb8d93..04284cac33d4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -470,7 +470,12 @@ ) from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline - from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline, CogVideoXFunControlPipeline + from .cogvideo import ( + CogVideoXFunControlPipeline, + CogVideoXImageToVideoPipeline, + CogVideoXPipeline, + CogVideoXVideoToVideoPipeline, + ) from .controlnet import ( BlipDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, diff --git a/src/diffusers/pipelines/cogvideo/__init__.py b/src/diffusers/pipelines/cogvideo/__init__.py index 2cb5c7a42c0a..e4fa1dda53d3 100644 --- a/src/diffusers/pipelines/cogvideo/__init__.py +++ b/src/diffusers/pipelines/cogvideo/__init__.py @@ -23,9 +23,9 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"] + _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"] _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"] _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"] - _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -36,9 +36,9 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_cogvideox import CogVideoXPipeline + from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline - from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline else: import sys diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 8727056fd50c..92a3bd483b5c 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -18,11 +18,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import T5EncoderModel, T5Tokenizer from PIL import Image +from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import VaeImageProcessor from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed @@ -44,11 +43,15 @@ >>> from diffusers import CogVideoXFunControlPipeline, DDIMScheduler >>> from diffusers.utils import export_to_video, load_video - >>> pipe = CogVideoXFunControlPipeline.from_pretrained("alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16) + >>> pipe = CogVideoXFunControlPipeline.from_pretrained( + ... "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16 + ... ) >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) >>> pipe.to("cuda") - >>> control_video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4") + >>> control_video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" + ... ) >>> prompt = ( ... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and " ... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in " @@ -350,7 +353,9 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents - def prepare_control_latents(self, mask: Optional[torch.Tensor] = None, masked_image: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + def prepare_control_latents( + self, mask: Optional[torch.Tensor] = None, masked_image: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: if mask is not None: masks = [] for i in range(mask.size(0)): @@ -358,7 +363,7 @@ def prepare_control_latents(self, mask: Optional[torch.Tensor] = None, masked_im current_mask = self.vae.encode(current_mask)[0] current_mask = current_mask.mode() masks.append(current_mask) - mask = torch.cat(masks, dim = 0) + mask = torch.cat(masks, dim=0) mask = mask * self.vae.config.scaling_factor if masked_image is not None: @@ -368,7 +373,7 @@ def prepare_control_latents(self, mask: Optional[torch.Tensor] = None, masked_im mask_pixel_value = self.vae.encode(mask_pixel_value)[0] mask_pixel_value = mask_pixel_value.mode() mask_pixel_values.append(mask_pixel_value) - masked_image_latents = torch.cat(mask_pixel_values, dim = 0) + masked_image_latents = torch.cat(mask_pixel_values, dim=0) masked_image_latents = masked_image_latents * self.vae.config.scaling_factor else: masked_image_latents = None @@ -453,9 +458,11 @@ def check_inputs( f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) - + if control_video is not None and control_video_latents is not None: - raise ValueError(f"Cannot pass both `control_video` and `control_video_latents`. Please make sure to pass only one of these parameters.") + raise ValueError( + "Cannot pass both `control_video` and `control_video_latents`. Please make sure to pass only one of these parameters." + ) def fuse_qkv_projections(self) -> None: r"""Enables fused QKV projections.""" @@ -554,8 +561,8 @@ def __call__( `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). control_video (`List[PIL.Image.Image]`): - The control video to condition the generation on. Must be a list of images/frames of the video. If not provided, - `control_video_latents` must be provided. + The control video to condition the generation on. Must be a list of images/frames of the video. If not + provided, `control_video_latents` must be provided. height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): The height in pixels of the generated image. This is set to 480 by default for the best results. width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): @@ -583,8 +590,8 @@ def __call__( generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. control_video_latents (`torch.Tensor`, *optional*): - Pre-generated control latents, sampled from a Gaussian distribution, to be used as inputs for controlled - video generation. If not provided, `control_video` must be provided. + Pre-generated control latents, sampled from a Gaussian distribution, to be used as inputs for + controlled video generation. If not provided, `control_video` must be provided. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -651,7 +658,7 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - + if control_video is not None and isinstance(control_video[0], Image.Image): control_video = [control_video] @@ -698,7 +705,7 @@ def __call__( if control_video_latents is None: control_video = self.video_processor.preprocess_video(control_video, height=height, width=width) control_video = control_video.to(device=device, dtype=prompt_embeds.dtype) - + _, control_video_latents = self.prepare_control_latents(None, control_video) control_video_latents = control_video_latents.permute(0, 2, 1, 3, 4) @@ -725,7 +732,9 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_control_input = torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + latent_control_input = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ) latent_model_input = torch.cat([latent_model_input, latent_control_input], dim=2) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py index e6d417fefca2..2a51fc65798c 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py +++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py @@ -12,21 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import inspect import unittest import numpy as np import torch -from transformers import AutoTokenizer, T5EncoderModel from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXFunControlPipeline, CogVideoXTransformer3DModel, DDIMScheduler from diffusers.utils.testing_utils import ( enable_full_determinism, - numpy_cosine_similarity_distance, - require_torch_gpu, - slow, torch_device, ) @@ -123,13 +119,13 @@ def get_dummy_inputs(self, device, seed: int = 0, num_frames: int = 8): generator = torch.manual_seed(seed) else: generator = torch.Generator(device=device).manual_seed(seed) - + # Cannot reduce because convolution kernel becomes bigger than sample height = 16 width = 16 - + control_video = [Image.new("RGB", (width, height))] * num_frames - + inputs = { "prompt": "dance monkey", "negative_prompt": "", From cc043a2b241f6dd38e4a68c7714eccfb836d4917 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Oct 2024 17:06:00 +0200 Subject: [PATCH 3/7] make fix-copies --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index fee69d01ebff..242866ab8b75 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -272,6 +272,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class CogVideoXFunControlPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CogVideoXImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From f8215d37357c9ae071b99b199738c2a4ab01a991 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Oct 2024 02:24:52 +0200 Subject: [PATCH 4/7] karras schedulers --- .../cogvideo/pipeline_cogvideox_fun_control.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 92a3bd483b5c..3cd9921bb24f 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -26,7 +26,7 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline -from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -182,7 +182,7 @@ def __init__( text_encoder: T5EncoderModel, vae: AutoencoderKLCogVideoX, transformer: CogVideoXTransformer3DModel, - scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + scheduler: KarrasDiffusionSchedulers, ): super().__init__() @@ -761,18 +761,7 @@ def __call__( noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if not isinstance(self.scheduler, CogVideoXDPMScheduler): - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - else: - latents, old_pred_original_sample = self.scheduler.step( - noise_pred, - old_pred_original_sample, - t, - timesteps[i - 1] if i > 0 else None, - latents, - **extra_step_kwargs, - return_dict=False, - ) + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = latents.to(prompt_embeds.dtype) # call the callback, if provided From a0f18ce3dfc3f510f6527fb575e2b8bf98a422f0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Oct 2024 02:02:31 +0530 Subject: [PATCH 5/7] Update src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- .../pipelines/cogvideo/pipeline_cogvideox_fun_control.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 3cd9921bb24f..e60861067e4b 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -574,7 +574,7 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): + guidance_scale (`float`, *optional*, defaults to 6.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > From d0dbcb799be7d3affa82c89092760a4eec4cb508 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Oct 2024 02:02:36 +0530 Subject: [PATCH 6/7] Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/cogvideox.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 58d7a0ef1017..f0f4fd37e6d5 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -36,7 +36,7 @@ There are two models available that can be used with the text-to-video and video There is one model available that can be used with the image-to-video CogVideoX pipeline: - [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`. -There are two models that allow controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team): +There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team): - [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`. - [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`. From 8c58ac294c09b68d9add50325dd9ee9facf9b9c8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Oct 2024 12:35:15 +0200 Subject: [PATCH 7/7] apply suggestions from review --- .../pipelines/cogvideo/pipeline_cogvideox_fun_control.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index e60861067e4b..ba328bdc094f 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -353,6 +353,7 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents + # Adapted from https://github.com/aigc-apps/CogVideoX-Fun/blob/2a93e5c14e02b2b5921d533fd59fc8c0ed69fb24/cogvideox/pipeline/pipeline_cogvideox_control.py#L366 def prepare_control_latents( self, mask: Optional[torch.Tensor] = None, masked_image: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: