diff --git a/examples/community/README.md b/examples/community/README.md index d3223654c23a..f69a81c59baf 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -3561,14 +3561,17 @@ pipe.disable_style_aligned() This pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results. +This pipeline relies on a "hack" discovered by the community that allows the generation of videos given an input image with AnimateDiff. It works by creating a copy of the image `num_frames` times and progressively adding more noise to the image based on the strength and latent interpolation method. + ```py import torch from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler from diffusers.utils import export_to_gif, load_image +model_id = "SG161222/Realistic_Vision_V5.1_noVAE" adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") -pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda") -pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace") +pipe = DiffusionPipeline.from_pretrained(model_id, motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda") +pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1) image = load_image("snail.png") output = pipe( diff --git a/examples/community/pipeline_animatediff_controlnet.py b/examples/community/pipeline_animatediff_controlnet.py index 1285e7c97a9b..5873ceaa8d70 100644 --- a/examples/community/pipeline_animatediff_controlnet.py +++ b/examples/community/pipeline_animatediff_controlnet.py @@ -24,7 +24,7 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel +from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.models.unets.unet_motion_model import MotionAdapter from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel @@ -382,6 +382,41 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + image_embeds = ip_adapter_image_embeds + return image_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents @@ -767,6 +802,7 @@ def __call__( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[PipelineImageInput] = None, conditioning_frames: Optional[List[PipelineImageInput]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -821,6 +857,9 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. conditioning_frames (`List[PipelineImageInput]`, *optional*): The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets are specified, images must be passed as a list such that each element of the list can be correctly @@ -965,9 +1004,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt + ) if isinstance(controlnet, ControlNetModel): conditioning_frames = self.prepare_image( @@ -1023,7 +1062,11 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Add image embeds for IP-Adapter
-        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+            else None
+        )

         # 7.1 Create tensor stating which controlnets to keep
         controlnet_keep = []
diff --git a/examples/community/pipeline_animatediff_img2video.py b/examples/community/pipeline_animatediff_img2video.py
index 826742f9afc8..e77e26592d3e 100644
--- a/examples/community/pipeline_animatediff_img2video.py
+++ b/examples/community/pipeline_animatediff_img2video.py
@@ -11,9 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
+# Note:
+# This pipeline relies on a "hack" discovered by the community that allows
+# the generation of videos given an input image with AnimateDiff. It works
+# by creating a copy of the image `num_frames` times and progressively adding
+# more noise to the image based on the strength and latent interpolation method.

 import inspect
-from dataclasses import dataclass
 from types import FunctionType
 from typing import Any, Callable, Dict, List, Optional, Union

@@ -25,7 +30,8 @@
 from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
-from diffusers.models.unet_motion_model import MotionAdapter
+from diffusers.models.unets.unet_motion_model import MotionAdapter
+from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import (
     DDIMScheduler,
@@ -35,7 +41,7 @@
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from diffusers.utils.torch_utils import randn_tensor


@@ -48,9 +54,10 @@
     >>> from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
     >>> from diffusers.utils import export_to_gif, load_image

+    >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
     >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
     >>> pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
-    >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
+    >>> pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1)

     >>> image = load_image("snail.png")
     >>> output = pipe(image=image, prompt="A snail moving on the ground", strength=0.8, latent_interpolation_method="slerp")
@@ -225,14 +232,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-@dataclass
-class AnimateDiffImgToVideoPipelineOutput(BaseOutput):
-    frames: Union[torch.Tensor, np.ndarray]
-
-
 class AnimateDiffImgToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
     r"""
-    Pipeline for text-to-video generation.
+    Pipeline for image-to-video generation.

This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). @@ -503,6 +505,41 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + image_embeds = ip_adapter_image_embeds + return image_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents @@ -765,6 +802,7 @@ def __call__( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -818,6 +856,9 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or `np.array`. @@ -842,8 +883,8 @@ def __call__( Examples: Returns: - [`AnimateDiffImgToVideoPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`AnimateDiffImgToVideoPipelineOutput`] is + [`AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. """ # 0. 
Default height and width to unet @@ -902,12 +943,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt ) - if do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. Preprocess image image = self.image_processor.preprocess(image, height=height, width=width) @@ -936,7 +974,11 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -970,7 +1012,7 @@ def __call__( callback(i, t, latents) if output_type == "latent": - return AnimateDiffImgToVideoPipelineOutput(frames=latents) + return AnimateDiffPipelineOutput(frames=latents) # 10. Post-processing video_tensor = self.decode_latents(latents) @@ -986,4 +1028,4 @@ def __call__( if not return_dict: return (video,) - return AnimateDiffImgToVideoPipelineOutput(frames=video) + return AnimateDiffPipelineOutput(frames=video)
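To make the "hack" described in the README and in the file-header note above more concrete, here is a minimal, self-contained sketch of the latent-preparation idea: the input image's latent is repeated `num_frames` times and each copy is pushed progressively further toward random noise, using either linear (`lerp`) or spherical (`slerp`) interpolation, with `strength` capping how far the last frame drifts. This is an illustration only, not the pipeline's actual `prepare_latents` implementation; the helpers `prepare_video_latents_sketch` and `slerp` are hypothetical names and do not exist in diffusers.

```py
# A standalone sketch, assuming only PyTorch; none of these helpers exist in diffusers.
from typing import Optional

import torch


def slerp(v0: torch.Tensor, v1: torch.Tensor, t: float, eps: float = 1e-7) -> torch.Tensor:
    """Spherical interpolation between two tensors, treated as flattened vectors."""
    a, b = v0.flatten(), v1.flatten()
    cos_theta = torch.clamp(torch.dot(a / a.norm(), b / b.norm()), -1.0 + eps, 1.0 - eps)
    theta = torch.acos(cos_theta)
    return (torch.sin((1.0 - t) * theta) * v0 + torch.sin(t * theta) * v1) / torch.sin(theta)


def prepare_video_latents_sketch(
    image_latent: torch.Tensor,  # [C, H, W] latent of the single input image
    num_frames: int = 16,
    strength: float = 0.8,
    method: str = "slerp",  # "lerp" or "slerp"
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Repeat the image latent `num_frames` times, noising later frames more heavily."""
    noise = torch.randn(image_latent.shape, generator=generator, dtype=image_latent.dtype)
    frames = []
    for i in range(num_frames):
        # Interpolation weight grows with the frame index; `strength` controls how far
        # the final frame is pushed toward pure noise (strength=0 -> a static video).
        t = strength * i / max(num_frames - 1, 1)
        if method == "slerp":
            frames.append(slerp(image_latent, noise, t))
        else:
            frames.append(torch.lerp(image_latent, noise, t))
    # Stack to [C, F, H, W], the channels-first, frames-second layout used for video latents.
    return torch.stack(frames, dim=1)


if __name__ == "__main__":
    latent = torch.randn(4, 64, 64)
    video_latents = prepare_video_latents_sketch(latent, num_frames=16, strength=0.8)
    print(video_latents.shape)  # torch.Size([4, 16, 64, 64])
```

Spherical interpolation tends to preserve the magnitude of the latents better than linear interpolation when blending toward Gaussian noise, which is likely why the README example passes `latent_interpolation_method="slerp"`.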