diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md
index 4e1670df7717..e00446857fc5 100644
--- a/docs/source/en/api/pipelines/animatediff.md
+++ b/docs/source/en/api/pipelines/animatediff.md
@@ -25,6 +25,7 @@ The abstract of the paper is the following:
| Pipeline | Tasks | Demo
|---|---|:---:|
| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* |
+| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* |
## Available checkpoints
@@ -32,6 +33,8 @@ Motion Adapter checkpoints can be found under [guoyww](https://huggingface.co/gu
## Usage example
+### AnimateDiffPipeline
+
AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in Stable Diffusion UNet.
The following example demonstrates how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
@@ -98,6 +101,114 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you
+### AnimateDiffVideoToVideoPipeline
+
+AnimateDiff can also be used to generate visually similar videos or enable style/character/background or other edits starting from an initial video, allowing you to seamlessly explore creative possibilities.
+
+```python
+import imageio
+import requests
+import torch
+from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
+from diffusers.utils import export_to_gif
+from io import BytesIO
+from PIL import Image
+
+# Load the motion adapter
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
+# load SD 1.5 based finetuned model
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
+scheduler = DDIMScheduler.from_pretrained(
+ model_id,
+ subfolder="scheduler",
+ clip_sample=False,
+ timestep_spacing="linspace",
+ beta_schedule="linear",
+ steps_offset=1,
+)
+pipe.scheduler = scheduler
+
+# enable memory savings
+pipe.enable_vae_slicing()
+pipe.enable_model_cpu_offload()
+
+# helper function to load videos
+def load_video(file_path: str):
+ images = []
+
+ if file_path.startswith(('http://', 'https://')):
+ # If the file_path is a URL
+ response = requests.get(file_path)
+ response.raise_for_status()
+ content = BytesIO(response.content)
+ vid = imageio.get_reader(content)
+ else:
+ # Assuming it's a local file path
+ vid = imageio.get_reader(file_path)
+
+ for frame in vid:
+ pil_image = Image.fromarray(frame)
+ images.append(pil_image)
+
+ return images
+
+video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
+
+output = pipe(
+ video = video,
+ prompt="panda playing a guitar, on a boat, in the ocean, high quality",
+ negative_prompt="bad quality, worse quality",
+ guidance_scale=7.5,
+ num_inference_steps=25,
+ strength=0.5,
+ generator=torch.Generator("cpu").manual_seed(42),
+)
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```
+
+Here are some sample outputs:
+
+
+
+ | Source Video |
+ Output Video |
+
+
+
+ raccoon playing a guitar
+
+
+ |
+
+ panda playing a guitar
+
+
+ |
+
+
+
+ closeup of margot robbie, fireworks in the background, high quality
+
+
+ |
+
+ closeup of tony stark, robert downey jr, fireworks
+
+
+ |
+
+
+
## Using Motion LoRAs
Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-motion-adapter-v1-5-2` checkpoint. These LoRAs are responsible for adding specific types of motion to the animations.
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 107547abf2f5..931ccc8a4a9d 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -208,6 +208,7 @@
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffPipeline",
+ "AnimateDiffVideoToVideoPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
@@ -569,6 +570,7 @@
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffPipeline,
+ AnimateDiffVideoToVideoPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 13163f96e059..9d11f78ceee2 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -109,7 +109,10 @@
]
)
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
- _import_structure["animatediff"] = ["AnimateDiffPipeline"]
+ _import_structure["animatediff"] = [
+ "AnimateDiffPipeline",
+ "AnimateDiffVideoToVideoPipeline",
+ ]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -341,7 +344,7 @@
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
- from .animatediff import AnimateDiffPipeline
+ from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import (
AudioLDM2Pipeline,
diff --git a/src/diffusers/pipelines/animatediff/__init__.py b/src/diffusers/pipelines/animatediff/__init__.py
index 503352fec865..35b99a76fd21 100644
--- a/src/diffusers/pipelines/animatediff/__init__.py
+++ b/src/diffusers/pipelines/animatediff/__init__.py
@@ -11,7 +11,7 @@
_dummy_objects = {}
-_import_structure = {}
+_import_structure = {"pipeline_output": ["AnimateDiffPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
@@ -21,7 +21,8 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
- _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline", "AnimateDiffPipelineOutput"]
+ _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
+ _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -31,7 +32,9 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
- from .pipeline_animatediff import AnimateDiffPipeline, AnimateDiffPipelineOutput
+ from .pipeline_animatediff import AnimateDiffPipeline
+ from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
+ from .pipeline_output import AnimateDiffPipelineOutput
else:
import sys
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index 25b664fe3b29..ee1062ee81ff 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -14,7 +14,6 @@
import inspect
import math
-from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -37,7 +36,6 @@
)
from ...utils import (
USE_PEFT_BACKEND,
- BaseOutput,
deprecate,
logging,
replace_example_docstring,
@@ -46,6 +44,7 @@
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import AnimateDiffPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -153,11 +152,6 @@ def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> tor
return x_mixed
-@dataclass
-class AnimateDiffPipelineOutput(BaseOutput):
- frames: Union[torch.Tensor, np.ndarray]
-
-
class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
new file mode 100644
index 000000000000..3d01009cbac7
--- /dev/null
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -0,0 +1,969 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...models.unets.unet_motion_model import MotionAdapter
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import AnimateDiffPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import imageio
+ >>> import requests
+ >>> import torch
+ >>> from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
+ >>> from diffusers.utils import export_to_gif
+ >>> from io import BytesIO
+ >>> from PIL import Image
+
+ >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
+ >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter).to("cuda")
+ >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
+
+ >>> def load_video(file_path: str):
+ ... images = []
+ ...
+ ... if file_path.startswith(('http://', 'https://')):
+ ... # If the file_path is a URL
+ ... response = requests.get(file_path)
+ ... response.raise_for_status()
+ ... content = BytesIO(response.content)
+ ... vid = imageio.get_reader(content)
+ ... else:
+ ... # Assuming it's a local file path
+ ... vid = imageio.get_reader(file_path)
+ ...
+ ... for frame in vid:
+ ... pil_image = Image.fromarray(frame)
+ ... images.append(pil_image)
+ ...
+ ... return images
+
+ >>> video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
+ >>> output = pipe(video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5)
+ >>> frames = output.frames[0]
+ >>> export_to_gif(frames, "animation.gif")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
+def tensor2vid(video: torch.Tensor, processor, output_type="np"):
+ batch_size, channels, num_frames, height, width = video.shape
+ outputs = []
+ for batch_idx in range(batch_size):
+ batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+ batch_output = processor.postprocess(batch_vid, output_type)
+
+ outputs.append(batch_output)
+
+ if output_type == "np":
+ outputs = np.stack(outputs)
+
+ elif output_type == "pt":
+ outputs = torch.stack(outputs)
+
+ elif not output_type == "pil":
+ raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+
+ return outputs
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for video-to-video generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (`CLIPTokenizer`):
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+ motion_adapter ([`MotionAdapter`]):
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["feature_extractor", "image_encoder"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ motion_adapter: MotionAdapter,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ feature_extractor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ ):
+ super().__init__()
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ motion_adapter=motion_adapter,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ batch_size, channels, num_frames, height, width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+
+ image = self.vae.decode(latents).sample
+ video = (
+ image[None, :]
+ .reshape(
+ (
+ batch_size,
+ num_frames,
+ -1,
+ )
+ + image.shape[2:]
+ )
+ .permute(0, 2, 1, 3, 4)
+ )
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ strength,
+ height,
+ width,
+ video=None,
+ latents=None,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if video is not None and latents is not None:
+ raise ValueError("Only one of `video` or `latents` should be provided")
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(
+ self,
+ video,
+ height,
+ width,
+ num_channels_latents,
+ batch_size,
+ timestep,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # video must be a list of list of images
+ # the outer list denotes having multiple videos as input, whereas inner list means the frames of the video
+ # as a list of images
+ if not isinstance(video[0], list):
+ video = [video]
+ if latents is None:
+ video = torch.cat(
+ [self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0
+ )
+ video = video.to(device=device, dtype=dtype)
+ num_frames = video.shape[1]
+ else:
+ num_frames = latents.shape[2]
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.config.force_upcast:
+ video = video.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ if len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ init_latents = [
+ retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
+ for i in range(batch_size)
+ ]
+ else:
+ init_latents = [
+ retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
+ ]
+
+ init_latents = torch.cat(init_latents, dim=0)
+
+ # restore vae to original dtype
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ error_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
+ )
+ raise ValueError(error_message)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
+ else:
+ if shape != latents.shape:
+ # [B, C, F, H, W]
+ raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
+ latents = latents.to(device, dtype=dtype)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ video: List[List[PipelineImageInput]] = None,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 7.5,
+ strength: float = 0.8,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ video (`List[PipelineImageInput]`):
+ The input video to condition the generation on. Must be a list of images/frames of the video.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
+ expense of slower inference.
+ strength (`float`, *optional*, defaults to 0.8):
+ Higher strength leads to more differences between original video and generated video.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
+ `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`AnimateDiffPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeine class.
+
+ Examples:
+
+ Returns:
+ [`AnimateDiffPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ strength=strength,
+ height=height,
+ width=width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ video=video,
+ latents=latents,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None:
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
+ )
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ video=video,
+ height=height,
+ width=width,
+ num_channels_latents=num_channels_latents,
+ batch_size=batch_size * num_videos_per_prompt,
+ timestep=latent_timestep,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Add image embeds for IP-Adapter
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ ).sample
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ progress_bar.update()
+
+ if output_type == "latent":
+ return AnimateDiffPipelineOutput(frames=latents)
+
+ # 9. Post-processing
+ video_tensor = self.decode_latents(latents)
+
+ if output_type == "pt":
+ video = video_tensor
+ else:
+ video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+
+ # 10. Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return AnimateDiffPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/animatediff/pipeline_output.py b/src/diffusers/pipelines/animatediff/pipeline_output.py
new file mode 100644
index 000000000000..79a399fbc3a6
--- /dev/null
+++ b/src/diffusers/pipelines/animatediff/pipeline_output.py
@@ -0,0 +1,22 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class AnimateDiffPipelineOutput(BaseOutput):
+ r"""
+ Output class for AnimateDiff pipelines.
+
+ Args:
+ frames (`List[List[PIL.Image.Image]]` or `torch.Tensor` or `np.ndarray`):
+ List of PIL Images of length `batch_size` or torch.Tensor or np.ndarray of shape
+ `(batch_size, num_frames, height, width, num_channels)`.
+ """
+
+ frames: Union[List[List[PIL.Image.Image]], torch.Tensor, np.ndarray]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 2eb9599658d9..e0d5c77d0e8c 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -92,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class AudioLDM2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py
new file mode 100644
index 000000000000..3226bdb3ca6e
--- /dev/null
+++ b/tests/pipelines/animatediff/test_animatediff_video2video.py
@@ -0,0 +1,269 @@
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import (
+ AnimateDiffVideoToVideoPipeline,
+ AutoencoderKL,
+ DDIMScheduler,
+ MotionAdapter,
+ UNet2DConditionModel,
+ UNetMotionModel,
+)
+from diffusers.utils import is_xformers_available, logging
+from diffusers.utils.testing_utils import torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+def to_np(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.detach().cpu().numpy()
+
+ return tensor
+
+
+class AnimateDiffVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = AnimateDiffVideoToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS
+ batch_params = VIDEO_TO_VIDEO_BATCH_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ norm_num_groups=2,
+ )
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="linear",
+ clip_sample=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ motion_adapter = MotionAdapter(
+ block_out_channels=(32, 64),
+ motion_layers_per_block=2,
+ motion_norm_num_groups=2,
+ motion_num_attention_heads=4,
+ )
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "motion_adapter": motion_adapter,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "feature_extractor": None,
+ "image_encoder": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ video_height = 32
+ video_width = 32
+ video_num_frames = 2
+ video = [Image.new("RGB", (video_width, video_height))] * video_num_frames
+
+ inputs = {
+ "video": video,
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 7.5,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_motion_unet_loading(self):
+ components = self.get_dummy_components()
+ pipe = AnimateDiffVideoToVideoPipeline(**components)
+
+ assert isinstance(pipe.unet, UNetMotionModel)
+
+ @unittest.skip("Attention slicing is not enabled in this pipeline")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_inference_batch_single_identical(
+ self,
+ batch_size=2,
+ expected_max_diff=1e-4,
+ additional_params_copy_to_batched_inputs=["num_inference_steps"],
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for components in pipe.components.values():
+ if hasattr(components, "set_default_attn_processor"):
+ components.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(torch_device)
+ # Reset generator in case it is has been used in self.get_dummy_inputs
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # batchify inputs
+ batched_inputs = {}
+ batched_inputs.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ if name == "prompt":
+ len_prompt = len(value)
+ batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+ batched_inputs[name][-1] = 100 * "very long"
+
+ else:
+ batched_inputs[name] = batch_size * [value]
+
+ if "generator" in inputs:
+ batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_inputs["batch_size"] = batch_size
+
+ for arg in additional_params_copy_to_batched_inputs:
+ batched_inputs[arg] = inputs[arg]
+
+ output = pipe(**inputs)
+ output_batch = pipe(**batched_inputs)
+
+ assert output_batch[0].shape[0] == batch_size
+
+ max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
+ assert max_diff < expected_max_diff
+
+ @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ def test_to_device(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.to("cpu")
+ # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cpu" for device in model_devices))
+
+ output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
+ self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+ pipe.to("cuda")
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cuda" for device in model_devices))
+
+ output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+ self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ pipe.to(torch_dtype=torch.float16)
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
+ def test_prompt_embeds(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("prompt")
+ inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
+ pipe(**inputs)
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_without_offload = pipe(**inputs).frames[0]
+ output_without_offload = (
+ output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
+ )
+
+ pipe.enable_xformers_memory_efficient_attention()
+ inputs = self.get_dummy_inputs(torch_device)
+ output_with_offload = pipe(**inputs).frames[0]
+ output_with_offload = (
+ output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
+ )
+
+ max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+ self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
diff --git a/tests/pipelines/pipeline_params.py b/tests/pipelines/pipeline_params.py
index f5be787656c7..4e2c4dcdd9cb 100644
--- a/tests/pipelines/pipeline_params.py
+++ b/tests/pipelines/pipeline_params.py
@@ -125,3 +125,5 @@
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
+
+VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])