diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md
index cb856fe0acfc..12afe4d2c436 100644
--- a/docs/source/en/api/pipelines/wan.md
+++ b/docs/source/en/api/pipelines/wan.md
@@ -133,6 +133,60 @@ output = pipe(
 export_to_video(output, "wan-i2v.mp4", fps=16)
 ```
 
+### First and Last Frame Interpolation
+
+```python
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+
+
+model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanImageToVideoPipeline.from_pretrained(
+    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
+last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
+
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+    aspect_ratio = image.height / image.width
+    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+    image = image.resize((width, height))
+    return image, height, width
+
+def center_crop_resize(image, height, width):
+    # Scale the image so it covers the target (height, width) of the first frame
+    resize_ratio = max(width / image.width, height / image.height)
+    resized_width = round(image.width * resize_ratio)
+    resized_height = round(image.height * resize_ratio)
+    image = image.resize((resized_width, resized_height))
+
+    # Center crop to the exact target size (torchvision expects [height, width])
+    image = TF.center_crop(image, [height, width])
+
+    return image, height, width
+
+first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
+if last_frame.size != first_frame.size:
+    last_frame, _, _ = center_crop_resize(last_frame, height, width)
+
+prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+
+output = pipe(
+    image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+```
+
 ### Video to Video Generation
 
 ```python
diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py
index 0b2fa872487e..f9b85bf54cc8 100644
--- a/scripts/convert_wan_to_diffusers.py
+++ b/scripts/convert_wan_to_diffusers.py
@@ -39,6 +39,24 @@
     "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
     "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
     "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+    # for the FLF2V model
+    "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+    # Add attention component mappings
+    "self_attn.q": "attn1.to_q",
+    "self_attn.k": "attn1.to_k",
+    "self_attn.v": "attn1.to_v",
+    "self_attn.o": "attn1.to_out.0",
+    "self_attn.norm_q": "attn1.norm_q",
+    "self_attn.norm_k": "attn1.norm_k",
+    "cross_attn.q": "attn2.to_q",
+    "cross_attn.k": "attn2.to_k",
+    "cross_attn.v": "attn2.to_v",
+    "cross_attn.o": "attn2.to_out.0",
+    "cross_attn.norm_q": "attn2.norm_q",
+    "cross_attn.norm_k": "attn2.norm_k",
+    "attn2.to_k_img": "attn2.add_k_proj",
+    "attn2.to_v_img": "attn2.add_v_proj",
+    "attn2.norm_k_img": "attn2.norm_added_k",
 }
 
 TRANSFORMER_SPECIAL_KEYS_REMAP = {}
@@ -135,6 +153,28 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
             "text_dim": 4096,
         },
     }
+    elif model_type == "Wan-FLF2V-14B-720P":
+        config = {
+            "model_id": "ypyp/Wan2.1-FLF2V-14B-720P",  # This is just a placeholder
+            "diffusers_config": {
+                "image_dim": 1280,
+                "added_kv_proj_dim": 5120,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 36,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "rope_max_seq_len": 1024,
+                "pos_embed_seq_len": 257 * 2,
+            },
+        }
     return config
 
 
@@ -397,7 +437,7 @@ def get_args():
         prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
     )
 
-    if "I2V" in args.model_type:
+    if "I2V" in args.model_type or "FLF2V" in args.model_type:
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
         )
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index aa03e97093aa..757cbd65c6bf 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -49,8 +49,10 @@ def __call__(
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
-            encoder_hidden_states_img = encoder_hidden_states[:, :257]
-            encoder_hidden_states = encoder_hidden_states[:, 257:]
+            # 512 is the context length of the text encoder, hardcoded for now
+            image_context_length = encoder_hidden_states.shape[1] - 512
+            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
 
@@ -108,14 +110,23 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
 
 
 class WanImageEmbedding(torch.nn.Module):
-    def __init__(self, in_features: int, out_features: int):
+    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
 
         self.norm1 = FP32LayerNorm(in_features)
         self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
         self.norm2 = FP32LayerNorm(out_features)
+        if pos_embed_seq_len is not None:
+            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+        else:
+            self.pos_embed = None
 
     def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+        if self.pos_embed is not None:
+            batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+            encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+            encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
         hidden_states = self.norm1(encoder_hidden_states_image)
         hidden_states = self.ff(hidden_states)
         hidden_states = self.norm2(hidden_states)
@@ -130,6 +141,7 @@ def __init__(
         time_proj_dim: int,
         text_embed_dim: int,
         image_embed_dim: Optional[int] = None,
+        pos_embed_seq_len: Optional[int] = None,
    ):
         super().__init__()
 
@@ -141,7 +153,7 @@ def __init__(
 
         self.image_embedder = None
         if image_embed_dim is not None:
-            self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
+            self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
 
     def forward(
         self,
@@ -350,6 +362,7 @@ def __init__(
         image_dim: Optional[int] = None,
         added_kv_proj_dim: Optional[int] = None,
         rope_max_seq_len: int = 1024,
+        pos_embed_seq_len: Optional[int] = None,
     ) -> None:
         super().__init__()
 
@@ -368,6 +381,7 @@ def __init__(
             time_proj_dim=inner_dim * 6,
             text_embed_dim=text_dim,
             image_embed_dim=image_dim,
+            pos_embed_seq_len=pos_embed_seq_len,
         )
 
         # 3. Transformer blocks
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
index 20ad84cb90d0..86d5496d1623 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -380,6 +380,7 @@ def prepare_latents(
         device: Optional[torch.device] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
+        last_image: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         latent_height = height // self.vae_scale_factor_spatial
@@ -398,9 +399,16 @@ def prepare_latents(
             latents = latents.to(device=device, dtype=dtype)
 
         image = image.unsqueeze(2)
-        video_condition = torch.cat(
-            [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
-        )
+        if last_image is None:
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+            )
+        else:
+            last_image = last_image.unsqueeze(2)
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+                dim=2,
+            )
         video_condition = video_condition.to(device=device, dtype=dtype)
 
         latents_mean = (
@@ -424,7 +432,11 @@ def prepare_latents(
         latent_condition = (latent_condition - latents_mean) * latents_std
 
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
-        mask_lat_size[:, :, list(range(1, num_frames))] = 0
+
+        if last_image is None:
+            mask_lat_size[:, :, list(range(1, num_frames))] = 0
+        else:
+            mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
         first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
         mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
@@ -476,6 +488,7 @@ def __call__(
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         image_embeds: Optional[torch.Tensor] = None,
+        last_image: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "np",
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -620,7 +633,10 @@ def __call__(
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
         if image_embeds is None:
-            image_embeds = self.encode_image(image, device)
+            if last_image is None:
+                image_embeds = self.encode_image(image, device)
+            else:
+                image_embeds = self.encode_image([image, last_image], device)
         image_embeds = image_embeds.repeat(batch_size, 1, 1)
         image_embeds = image_embeds.to(transformer_dtype)
 
@@ -631,6 +647,10 @@ def __call__(
         # 5. Prepare latent variables
         num_channels_latents = self.vae.config.z_dim
         image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+        if last_image is not None:
+            last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+                device, dtype=torch.float32
+            )
         latents, condition = self.prepare_latents(
             image,
             batch_size * num_videos_per_prompt,
@@ -642,6 +662,7 @@ def __call__(
             device,
             generator,
             latents,
+            last_image,
         )
 
         # 6. Denoising loop
diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py
index 53fa37dfae99..ffcd4d31b846 100644
--- a/tests/pipelines/wan/test_wan_image_to_video.py
+++ b/tests/pipelines/wan/test_wan_image_to_video.py
@@ -160,3 +160,90 @@ def test_attention_slicing_forward_pass(self):
     @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
     def test_inference_batch_single_identical(self):
         pass
+
+
+class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        vae = AutoencoderKLWan(
+            base_dim=3,
+            z_dim=16,
+            dim_mult=[1, 1, 1, 1],
+            num_res_blocks=1,
+            temperal_downsample=[False, True, True],
+        )
+
+        torch.manual_seed(0)
+        # TODO: impl FlowDPMSolverMultistepScheduler
+        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        torch.manual_seed(0)
+        transformer = WanTransformer3DModel(
+            patch_size=(1, 2, 2),
+            num_attention_heads=2,
+            attention_head_dim=12,
+            in_channels=36,
+            out_channels=16,
+            text_dim=32,
+            freq_dim=256,
+            ffn_dim=32,
+            num_layers=2,
+            cross_attn_norm=True,
+            qk_norm="rms_norm_across_heads",
+            rope_max_seq_len=32,
+            image_dim=4,
+            pos_embed_seq_len=2 * (4 * 4 + 1),
+        )
+
+        torch.manual_seed(0)
+        image_encoder_config = CLIPVisionConfig(
+            hidden_size=4,
+            projection_dim=4,
+            num_hidden_layers=2,
+            num_attention_heads=2,
+            image_size=4,
+            intermediate_size=16,
+            patch_size=1,
+        )
+        image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+        torch.manual_seed(0)
+        image_processor = CLIPImageProcessor(crop_size=4, size=4)
+
+        components = {
+            "transformer": transformer,
+            "vae": vae,
+            "scheduler": scheduler,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "image_encoder": image_encoder,
+            "image_processor": image_processor,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        image_height = 16
+        image_width = 16
+        image = Image.new("RGB", (image_width, image_height))
+        last_image = Image.new("RGB", (image_width, image_height))
+        inputs = {
+            "image": image,
+            "last_image": last_image,
+            "prompt": "dance monkey",
+            "negative_prompt": "negative",
+            "height": image_height,
+            "width": image_width,
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "num_frames": 9,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+        }
+        return inputs