# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ #
+ # Note:
+ # This pipeline relies on a "hack" discovered by the community that allows
+ # the generation of videos from an input image with AnimateDiff. It works
+ # by creating `num_frames` copies of the image and progressively adding
+ # more noise to them based on `strength` and the latent interpolation method.
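To make the trick concrete, here is a rough sketch of the idea in plain PyTorch. This is not the pipeline's actual `prepare_latents`; the helper name is hypothetical, and only linear interpolation is shown (the pipeline also supports "slerp"):

import torch

def sketch_frame_latents(image_latent, num_frames, strength, generator=None):
    # image_latent: (batch, channels, height, width) latent of the input image.
    noise = torch.randn(
        image_latent.shape,
        generator=generator,
        dtype=image_latent.dtype,
        device=image_latent.device,
    )
    frames = []
    for i in range(num_frames):
        # The interpolation weight grows with the frame index and is scaled by
        # `strength`, so later frames start closer to pure noise.
        weight = strength * i / max(num_frames - 1, 1)
        frames.append(torch.lerp(image_latent, noise, weight))
    # Stack to (batch, channels, num_frames, height, width), the layout the
    # motion UNet expects.
    return torch.stack(frames, dim=2)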

import inspect
- from dataclasses import dataclass
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Union

from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
- from diffusers.models.unet_motion_model import MotionAdapter
+ from diffusers.models.unets.unet_motion_model import MotionAdapter
+ from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import randn_tensor


        >>> from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
        >>> from diffusers.utils import export_to_gif, load_image

+         >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
        >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
        >>> pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
-         >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
+         >>> pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1)

        >>> image = load_image("snail.png")
        >>> output = pipe(image=image, prompt="A snail moving on the ground", strength=0.8, latent_interpolation_method="slerp")
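The example presumably continues by exporting the frames, which is what the `export_to_gif` import above is for; a plausible continuation, not part of this hunk:

        >>> frames = output.frames[0]
        >>> export_to_gif(frames, "animation.gif")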
@@ -225,14 +232,9 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


- @dataclass
- class AnimateDiffImgToVideoPipelineOutput(BaseOutput):
-     frames: Union[torch.Tensor, np.ndarray]
-
-
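For reference, the dedicated output class removed here is superseded by the shared `AnimateDiffPipelineOutput` imported at the top of the diff. Assuming it mirrors the fields this file uses, it looks roughly like this (a minimal sketch; see `diffusers/pipelines/animatediff/pipeline_output.py` for the real definition):

from dataclasses import dataclass
from typing import Union

import numpy as np
import torch

from diffusers.utils import BaseOutput


@dataclass
class AnimateDiffPipelineOutput(BaseOutput):
    # Generated frames: a torch.Tensor when output_type is "latent" or "pt",
    # otherwise a numpy array (or nested PIL frames, depending on output_type).
    frames: Union[torch.Tensor, np.ndarray]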
class AnimateDiffImgToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
    r"""
-     Pipeline for text-to-video generation.
+     Pipeline for image-to-video generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -503,6 +505,41 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state

        return image_embeds, uncond_image_embeds

+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+     def prepare_ip_adapter_image_embeds(
+         self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+     ):
+         if ip_adapter_image_embeds is None:
+             if not isinstance(ip_adapter_image, list):
+                 ip_adapter_image = [ip_adapter_image]
+
+             if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                 raise ValueError(
+                     f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                 )
+
+             image_embeds = []
+             for single_ip_adapter_image, image_proj_layer in zip(
+                 ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+             ):
+                 output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                 single_image_embeds, single_negative_image_embeds = self.encode_image(
+                     single_ip_adapter_image, device, 1, output_hidden_state
+                 )
+                 single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                 single_negative_image_embeds = torch.stack(
+                     [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                 )
+
+                 if self.do_classifier_free_guidance:
+                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                     single_image_embeds = single_image_embeds.to(device)
+
+                 image_embeds.append(single_image_embeds)
+         else:
+             image_embeds = ip_adapter_image_embeds
+         return image_embeds
+
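Note that this helper also folds in the classifier-free-guidance concat of negative and positive embeds (the `do_classifier_free_guidance` branch above), which is why the explicit `torch.cat` disappears from `__call__` further down. A hedged usage sketch of the two modes, assuming an IP Adapter is already loaded and `adapter_image` is a PIL image:

# Mode 1: no precomputed embeds -> each image is encoded through its projection layer.
embeds = pipe.prepare_ip_adapter_image_embeds(adapter_image, None, "cuda", 1)

# Mode 2: precomputed embeds are returned as-is, so the encoding cost is paid once
# and the same list can be reused across calls.
reused = pipe.prepare_ip_adapter_image_embeds(None, embeds, "cuda", 1)
assert reused is embeds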
    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
@@ -765,6 +802,7 @@ def __call__(
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
+         ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -818,6 +856,9 @@ def __call__(
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
+             ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                 Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
+                 `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
                `np.array`.
@@ -842,8 +883,8 @@ def __call__(
        Examples:

        Returns:
-             [`AnimateDiffImgToVideoPipelineOutput`] or `tuple`:
-                 If `return_dict` is `True`, [`AnimateDiffImgToVideoPipelineOutput`] is
+             [`AnimateDiffPipelineOutput`] or `tuple`:
+                 If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
        """
        # 0. Default height and width to unet
@@ -902,12 +943,9 @@ def __call__(
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
-             output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-             image_embeds, negative_image_embeds = self.encode_image(
-                 ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
+             image_embeds = self.prepare_ip_adapter_image_embeds(
+                 ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
            )
-             if do_classifier_free_guidance:
-                 image_embeds = torch.cat([negative_image_embeds, image_embeds])

        # 4. Preprocess image
        image = self.image_processor.preprocess(image, height=height, width=width)
@@ -936,7 +974,11 @@ def __call__(
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Add image embeds for IP-Adapter
-         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+         added_cond_kwargs = (
+             {"image_embeds": image_embeds}
+             if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+             else None
+         )

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
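A note on the warmup line just above: higher-order schedulers consume `scheduler.order` UNet evaluations per output step, so `len(timesteps) - num_inference_steps * scheduler.order` counts the leading timesteps the loop treats as warmup; the denoising loop uses it to advance the progress bar once per completed step.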
@@ -970,7 +1012,7 @@ def __call__(
                        callback(i, t, latents)

        if output_type == "latent":
-             return AnimateDiffImgToVideoPipelineOutput(frames=latents)
+             return AnimateDiffPipelineOutput(frames=latents)

        # 10. Post-processing
        video_tensor = self.decode_latents(latents)
@@ -986,4 +1028,4 @@ def __call__(
        if not return_dict:
            return (video,)

-         return AnimateDiffImgToVideoPipelineOutput(frames=video)
+         return AnimateDiffPipelineOutput(frames=video)