Commit bf2c6e0
Commit message: add
1 parent 95a55f9 commit bf2c6e0

4 files changed: +193 additions, -103 deletions

scripts/convert_wan_to_diffusers.py
Lines changed: 117 additions & 79 deletions
@@ -320,7 +320,27 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
         }
         RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
         SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
-        return config, RENAME_DICT, SPECIAL_KEYS_REMAP
+    elif model_type == "Wan2.2-TI2V-5B":
+        config = {
+            "model_id": "Wan-AI/Wan2.2-TI2V-5B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 14336,
+                "freq_dim": 256,
+                "in_channels": 48,
+                "num_attention_heads": 24,
+                "num_layers": 30,
+                "out_channels": 48,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+            },
+        }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     return config, RENAME_DICT, SPECIAL_KEYS_REMAP
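For orientation, the diffusers_config keys in the new branch map one-to-one onto WanTransformer3DModel constructor arguments; the 48 in/out channels match the 48 latent channels of the Wan2.2 VAE whose statistics appear in the next hunk, and 24 heads of dimension 128 give a 3072-dim hidden size. A minimal sketch (not part of the commit; it builds a randomly initialized model instead of loading converted weights):

from diffusers import WanTransformer3DModel

# Illustration only: the conversion script remaps original checkpoint keys into this
# architecture; constructing it directly just allocates random weights (~5B parameters).
transformer = WanTransformer3DModel(
    added_kv_proj_dim=None,
    attention_head_dim=128,
    cross_attn_norm=True,
    eps=1e-06,
    ffn_dim=14336,
    freq_dim=256,
    in_channels=48,
    num_attention_heads=24,
    num_layers=30,
    out_channels=48,
    patch_size=[1, 2, 2],
    qk_norm="rms_norm_across_heads",
    text_dim=4096,
)
print(sum(p.numel() for p in transformer.parameters()) / 1e9, "billion parameters")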

@@ -567,106 +587,110 @@ def convert_vae():
         "in_channels": 12,
         "out_channels": 12,
         "decoder_base_dim": 256,
+        "scale_factor_temporal": 4,
+        "scale_factor_spatial": 16,
+        "patch_size": 2,
         "latents_mean":[
-            -0.2289,
-            -0.0052,
-            -0.1323,
-            -0.2339,
+            -0.2289,
+            -0.0052,
+            -0.1323,
+            -0.2339,
             -0.2799,
-            -0.0174,
-            -0.1838,
-            -0.1557,
+            0.0174,
+            0.1838,
+            0.1557,
             -0.1382,
-            -0.0542,
-            -0.2813,
-            -0.0891,
-            -0.1570,
+            0.0542,
+            0.2813,
+            0.0891,
+            0.1570,
             -0.0098,
-            -0.0375,
+            0.0375,
             -0.1825,
             -0.2246,
             -0.1207,
             -0.0698,
-            -0.5109,
-            -0.2665,
+            0.5109,
+            0.2665,
             -0.2108,
             -0.2158,
-            -0.2502,
+            0.2502,
             -0.2055,
             -0.0322,
-            -0.1109,
-            -0.1567,
+            0.1109,
+            0.1567,
             -0.0729,
-            -0.0899,
+            0.0899,
             -0.2799,
             -0.1230,
             -0.0313,
             -0.1649,
-            -0.0117,
-            -0.0723,
+            0.0117,
+            0.0723,
             -0.2839,
             -0.2083,
             -0.0520,
-            -0.3748,
-            -0.0152,
-            -0.1957,
-            -0.1433,
+            0.3748,
+            0.0152,
+            0.1957,
+            0.1433,
             -0.2944,
-            -0.3573,
+            0.3573,
             -0.0548,
             -0.1681,
             -0.0667,
         ],
-        "latents_std":[
-            -0.4765,
-            -1.0364,
-            -0.4514,
-            -1.1677,
-            -0.5313,
-            -0.4990,
-            -0.4818,
-            -0.5013,
-            -0.8158,
-            -1.0344,
-            -0.5894,
-            -1.0901,
-            -0.6885,
-            -0.6165,
-            -0.8454,
-            -0.4978,
-            -0.5759,
-            -0.3523,
-            -0.7135,
-            -0.6804,
-            -0.5833,
-            -1.4146,
-            -0.8986,
-            -0.5659,
-            -0.7069,
-            -0.5338,
-            -0.4889,
-            -0.4917,
-            -0.4069,
-            -0.4999,
-            -0.6866,
-            -0.4093,
-            -0.5709,
-            -0.6065,
-            -0.6415,
-            -0.4944,
-            -0.5726,
-            -1.2042,
-            -0.5458,
-            -1.6887,
-            -0.3971,
-            -1.0600,
-            -0.3943,
-            -0.5537,
-            -0.5444,
-            -0.4089,
-            -0.7468,
-            -0.7744,
+        "latents_std": [
+            0.4765,
+            1.0364,
+            0.4514,
+            1.1677,
+            0.5313,
+            0.4990,
+            0.4818,
+            0.5013,
+            0.8158,
+            1.0344,
+            0.5894,
+            1.0901,
+            0.6885,
+            0.6165,
+            0.8454,
+            0.4978,
+            0.5759,
+            0.3523,
+            0.7135,
+            0.6804,
+            0.5833,
+            1.4146,
+            0.8986,
+            0.5659,
+            0.7069,
+            0.5338,
+            0.4889,
+            0.4917,
+            0.4069,
+            0.4999,
+            0.6866,
+            0.4093,
+            0.5709,
+            0.6065,
+            0.6415,
+            0.4944,
+            0.5726,
+            1.2042,
+            0.5458,
+            1.6887,
+            0.3971,
+            1.0600,
+            0.3943,
+            0.5537,
+            0.5444,
+            0.4089,
+            0.7468,
+            0.7744,
         ],
+        "clip_output": False,
     }
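The value changes above are a sign fix: several latents_mean entries are genuinely positive, and latents_std must be positive throughout. These are per-channel statistics for the 48-channel Wan2.2 latent space, used to normalize latents before the diffusion transformer and to undo that normalization before decoding. A rough sketch in plain torch (not code copied from the pipeline), assuming latents of shape (batch, 48, frames, height, width):

import torch

def normalize_latents(latents, latents_mean, latents_std):
    # latents_mean / latents_std are the 48-entry lists from the config above;
    # reshape them so they broadcast over the channel dimension.
    mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents)
    std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents)
    return (latents - mean) / std

def denormalize_latents(latents, latents_mean, latents_std):
    # Inverse mapping, applied before decoding latents back to pixels.
    mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents)
    std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents)
    return latents * std + mean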

@@ -855,7 +879,7 @@ def convert_vae_22():
         new_state_dict[key] = value

     with init_empty_weights():
-        vae = AutoencoderKLWan(**vae22_config)
+        vae = AutoencoderKLWan(**vae22_diffusers_config)
     vae.load_state_dict(new_state_dict, strict=True, assign=True)
     return vae

@@ -878,7 +902,7 @@ def get_args():
 if __name__ == "__main__":
     args = get_args()

-    if "Wan2.2" in args.model_type:
+    if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
         transformer = convert_transformer(args.model_type, stage="high_noise_model")
         transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
     else:
@@ -892,7 +916,12 @@ def get_args():

     text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
-    flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
+    if "FLF2V" in args.model_type:
+        flow_shift = 16.0
+    elif "TI2V" in args.model_type:
+        flow_shift = 5.0
+    else:
+        flow_shift = 3.0
     scheduler = UniPCMultistepScheduler(
         prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
     )
@@ -902,7 +931,7 @@ def get_args():
     dtype = DTYPE_MAPPING[args.dtype]
     transformer.to(dtype)

-    if "Wan2.2" and "I2V" in args.model_type:
+    if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
         pipe = WanImageToVideoPipeline(
             transformer=transformer,
             transformer_2=transformer_2,
@@ -922,6 +951,15 @@ def get_args():
             scheduler=scheduler,
             boundary_ratio=0.875,
         )
+    elif "Wan2.2" and "TI2V" in args.model_type:
+        pipe = WanPipeline(
+            transformer=transformer,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            vae=vae,
+            scheduler=scheduler,
+            expand_timesteps=True,
+        )
     elif "I2V" in args.model_type or "FLF2V" in args.model_type:
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16

src/diffusers/models/autoencoders/autoencoder_kl_wan.py
Lines changed: 5 additions & 1 deletion
@@ -1012,6 +1012,9 @@ def __init__(
         in_channels: int = 3,
         out_channels: int = 3,
         patch_size: Optional[int] = None,
+        scale_factor_temporal: Optional[int] = 4,
+        scale_factor_spatial: Optional[int] = 8,
+        clip_output: bool = True,
     ) -> None:
         super().__init__()
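scale_factor_temporal and scale_factor_spatial record the VAE's compression ratios; the defaults (4 and 8) match Wan2.1, while the conversion script above passes 4 and 16 for the Wan2.2 VAE together with patch_size=2. A small illustrative helper, not taken from the library, showing what the two factors imply for latent sizes:

def latent_shape(num_frames, height, width, scale_factor_temporal=4, scale_factor_spatial=16):
    # Causal temporal compression keeps the first frame and groups the remaining ones;
    # spatial dimensions shrink by the spatial factor.
    latent_frames = (num_frames - 1) // scale_factor_temporal + 1
    return latent_frames, height // scale_factor_spatial, width // scale_factor_spatial

print(latent_shape(121, 704, 1280))  # -> (31, 44, 80)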

@@ -1193,7 +1196,8 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
             out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
             out = torch.cat([out, out_], 2)

-        out = torch.clamp(out, min=-1.0, max=1.0)
+        if self.config.clip_output:
+            out = torch.clamp(out, min=-1.0, max=1.0)
         if self.config.patch_size is not None:
             out = unpatchify(out, patch_size=self.config.patch_size)
         self.clear_cache()
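With clip_output=False (as the Wan2.2 VAE config in the conversion script sets), _decode no longer clamps the reconstruction; callers that still want values in [-1, 1] can clamp afterwards. A hedged sketch using the standard diffusers decode call:

import torch

def decode_and_clamp(vae, latents: torch.Tensor) -> torch.Tensor:
    # vae is an AutoencoderKLWan; latents has shape (batch, z_dim, frames, height, width).
    # With clip_output=False the raw decoder output is returned, so clamp manually here.
    frames = vae.decode(latents, return_dict=False)[0]
    return frames.clamp(-1.0, 1.0)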

src/diffusers/models/transformers/transformer_wan.py
Lines changed: 44 additions & 7 deletions

@@ -170,8 +170,11 @@ def forward(
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         encoder_hidden_states_image: Optional[torch.Tensor] = None,
-    ):
+        timestep_seq_len: Optional[int] = None,
+    ):
         timestep = self.timesteps_proj(timestep)
+        if timestep_seq_len is not None:
+            timestep = timestep.unflatten(0, (1, timestep_seq_len))

         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
@@ -309,9 +312,24 @@ def forward(
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
     ) -> torch.Tensor:
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-            self.scale_shift_table + temb.float()
-        ).chunk(6, dim=1)
+
+        if temb.ndim == 4:
+            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table.unsqueeze(0) + temb.float()
+            ).chunk(6, dim=2)
+            # batch_size, seq_len, 1, inner_dim
+            shift_msa = shift_msa.squeeze(2)
+            scale_msa = scale_msa.squeeze(2)
+            gate_msa = gate_msa.squeeze(2)
+            c_shift_msa = c_shift_msa.squeeze(2)
+            c_scale_msa = c_scale_msa.squeeze(2)
+            c_gate_msa = c_gate_msa.squeeze(2)
+        else:
+            # temb: batch_size, 6, inner_dim
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table + temb.float()
+            ).chunk(6, dim=1)

         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
@@ -469,10 +487,22 @@ def forward(
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)

+        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+        if timestep.ndim == 2:
+            ts_seq_len = timestep.shape[1]
+            timestep = timestep.flatten()  # batch_size * seq_len
+        else:
+            ts_seq_len = None
+
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-            timestep, encoder_hidden_states, encoder_hidden_states_image
+            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
         )
-        timestep_proj = timestep_proj.unflatten(1, (6, -1))
+        if ts_seq_len is not None:
+            # batch_size, seq_len, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(2, (6, -1))
+        else:
+            # batch_size, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(1, (6, -1))

         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
@@ -488,7 +518,14 @@ def forward(
             hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

         # 5. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        if temb.ndim == 3:
+            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift = shift.squeeze(2)
+            scale = scale.squeeze(2)
+        else:
+            # batch_size, inner_dim
+            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
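Taken together, these changes let a per-token timestep drive the AdaLN modulation in Wan 2.2 TI2V: a timestep of shape (batch_size, seq_len) is flattened for the sinusoidal projection, the projected embedding is reshaped back to (batch_size, seq_len, 6, inner_dim), and each block broadcasts its scale_shift_table over the sequence dimension. A shape-only walkthrough with stand-in tensors (random features replace the real projection modules; batch size 1 mirrors the unflatten(0, (1, seq_len)) call above):

import torch

seq_len, inner_dim = 75, 3072  # illustrative; 3072 = 24 heads * 128 head_dim for TI2V-5B

timestep = torch.randint(0, 1000, (1, seq_len))              # (batch, seq_len), one value per token
ts_seq_len = timestep.shape[1]
timestep = timestep.flatten()                                 # (batch * seq_len,)

# Stand-in for the condition embedder's projected timestep output.
timestep_proj = torch.randn(1 * ts_seq_len, 6 * inner_dim)
timestep_proj = timestep_proj.unflatten(0, (1, ts_seq_len))   # (batch, seq_len, 6 * inner_dim)
timestep_proj = timestep_proj.unflatten(2, (6, -1))           # (batch, seq_len, 6, inner_dim)

# Inside each transformer block: broadcast the learned table over the sequence dimension.
scale_shift_table = torch.randn(1, 6, inner_dim)
mods = (scale_shift_table.unsqueeze(0) + timestep_proj.float()).chunk(6, dim=2)
shift_msa = mods[0].squeeze(2)                                # (batch, seq_len, inner_dim)
print(shift_msa.shape)                                        # torch.Size([1, 75, 3072])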
