Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion scripts/convert_wan_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,24 @@
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# for the FLF2V model
"img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {}
Expand Down Expand Up @@ -135,6 +153,28 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
elif model_type == "Wan-FLF2V-14B-720P":
config = {
"model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder
"diffusers_config": {
"image_dim": 1280,
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"rope_max_seq_len": 1024,
"pos_embed_seq_len": 257 * 2,
},
}
return config


Expand Down Expand Up @@ -397,7 +437,7 @@ def get_args():
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
)

if "I2V" in args.model_type:
if "I2V" in args.model_type or "FLF2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
Expand Down
22 changes: 18 additions & 4 deletions src/diffusers/models/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ def __call__(
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
encoder_hidden_states_img = encoder_hidden_states[:, :257]
encoder_hidden_states = encoder_hidden_states[:, 257:]
# 512 is the context length of the text encoder, hardcoded for now
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
Comment on lines -52 to +55
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this not backwards breaking? 👀

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i will test it out :)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states

Expand Down Expand Up @@ -108,14 +110,23 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):


class WanImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()

self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
self.norm2 = FP32LayerNorm(out_features)
if pos_embed_seq_len is not None:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, 1280))
else:
self.pos_embed = None

def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
if self.pos_embed is not None:
batch_size, seq_len, h = encoder_hidden_states_image.shape
encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, h)
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed

hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states)
Expand All @@ -130,6 +141,7 @@ def __init__(
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()

Expand All @@ -141,7 +153,7 @@ def __init__(

self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)

def forward(
self,
Expand Down Expand Up @@ -350,6 +362,7 @@ def __init__(
image_dim: Optional[int] = None,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
pos_embed_seq_len: Optional[int] = None,
) -> None:
super().__init__()

Expand All @@ -368,6 +381,7 @@ def __init__(
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
pos_embed_seq_len=pos_embed_seq_len,
)

# 3. Transformer blocks
Expand Down
31 changes: 26 additions & 5 deletions src/diffusers/pipelines/wan/pipeline_wan_i2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def prepare_latents(
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
last_image: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
Expand All @@ -398,9 +399,16 @@ def prepare_latents(
latents = latents.to(device=device, dtype=dtype)

image = image.unsqueeze(2)
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
)
if last_image is None:
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
)
else:
last_image = last_image.unsqueeze(2)
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
dim=2,
)
video_condition = video_condition.to(device=device, dtype=dtype)

latents_mean = (
Expand All @@ -424,7 +432,11 @@ def prepare_latents(
latent_condition = (latent_condition - latents_mean) * latents_std

mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, num_frames))] = 0

if last_image is None:
mask_lat_size[:, :, list(range(1, num_frames))] = 0
else:
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
Expand Down Expand Up @@ -476,6 +488,7 @@ def __call__(
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
last_image: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -620,7 +633,10 @@ def __call__(
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

if image_embeds is None:
image_embeds = self.encode_image(image, device)
if last_image is None:
image_embeds = self.encode_image(image, device)
else:
image_embeds = self.encode_image([image, last_image], device)
image_embeds = image_embeds.repeat(batch_size, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)

Expand All @@ -631,6 +647,10 @@ def __call__(
# 5. Prepare latent variables
num_channels_latents = self.vae.config.z_dim
image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
if last_image is not None:
last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
device, dtype=torch.float32
)
latents, condition = self.prepare_latents(
image,
batch_size * num_videos_per_prompt,
Expand All @@ -642,6 +662,7 @@ def __call__(
device,
generator,
latents,
last_image,
)

# 6. Denoising loop
Expand Down
Loading