@@ -39,6 +39,24 @@
     "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
     "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
     "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+    # for the FLF2V model
+    "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+    # Add attention component mappings
+    "self_attn.q": "attn1.to_q",
+    "self_attn.k": "attn1.to_k",
+    "self_attn.v": "attn1.to_v",
+    "self_attn.o": "attn1.to_out.0",
+    "self_attn.norm_q": "attn1.norm_q",
+    "self_attn.norm_k": "attn1.norm_k",
+    "cross_attn.q": "attn2.to_q",
+    "cross_attn.k": "attn2.to_k",
+    "cross_attn.v": "attn2.to_v",
+    "cross_attn.o": "attn2.to_out.0",
+    "cross_attn.norm_q": "attn2.norm_q",
+    "cross_attn.norm_k": "attn2.norm_k",
+    "attn2.to_k_img": "attn2.add_k_proj",
+    "attn2.to_v_img": "attn2.add_v_proj",
+    "attn2.norm_k_img": "attn2.norm_added_k",
 }
 
 TRANSFORMER_SPECIAL_KEYS_REMAP = {}
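The loop that consumes this mapping sits outside the hunk, but keys such as "attn2.to_k_img" suggest the entries are applied in order as plain substring replacements, so later entries can match intermediate names produced by earlier ones. A minimal sketch of that assumed behaviour; the dict excerpt, helper name, and checkpoint key below are illustrative, not taken from the script:

# Sketch (not part of the diff): rename dicts of this shape are usually applied
# as sequential substring replacements over every checkpoint key. The excerpt
# and example key are illustrative only.
RENAME_EXCERPT = {
    "cross_attn.k": "attn2.to_k",
    "attn2.to_k_img": "attn2.add_k_proj",
}

def rename_key(key: str, rename_dict: dict) -> str:
    for old, new in rename_dict.items():
        key = key.replace(old, new)
    return key

# "cross_attn.k" fires first and yields the intermediate "attn2.to_k_img",
# which the later "attn2.to_k_img" entry then maps to "attn2.add_k_proj".
assert rename_key("blocks.0.cross_attn.k_img.weight", RENAME_EXCERPT) == "blocks.0.attn2.add_k_proj.weight"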
@@ -135,6 +153,28 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
             "text_dim": 4096,
         },
     }
+    elif model_type == "Wan-FLF2V-14B-720P":
+        config = {
+            "model_id": "ypyp/Wan2.1-FLF2V-14B-720P",  # This is just a placeholder
+            "diffusers_config": {
+                "image_dim": 1280,
+                "added_kv_proj_dim": 5120,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 36,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "rope_max_seq_len": 1024,
+                "pos_embed_seq_len": 257 * 2,
+            },
+        }
     return config
 
 
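One value worth calling out is pos_embed_seq_len = 257 * 2. The diff does not explain it, but a plausible reading (an assumption, not stated in the PR) is sketched below: CLIP ViT-H/14 at 224x224 yields 257 tokens per image, and FLF2V conditions on two frames, the first and the last.

# Interpretation only, not stated in the diff:
# CLIP ViT-H/14 at 224x224 produces (224 // 14) ** 2 = 256 patch tokens plus
# one class token, i.e. 257 tokens per image. FLF2V conditions on two frames
# (first and last), so the image positional embedding spans 2 * 257 tokens.
tokens_per_image = (224 // 14) ** 2 + 1   # 257
pos_embed_seq_len = 2 * tokens_per_image  # 514 == 257 * 2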
@@ -393,11 +433,12 @@ def get_args():
     vae = convert_vae()
     text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
     tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+    flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
     scheduler = UniPCMultistepScheduler(
-        prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
+        prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
     )
 
-    if "I2V" in args.model_type:
+    if "I2V" in args.model_type or "FLF2V" in args.model_type:
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
         )
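For context, here is how the new branching plays out for an FLF2V checkpoint. The snippet only restates logic and kwargs visible in the diff; treating "Wan-FLF2V-14B-720P" as the value of args.model_type is an assumption about how the script is invoked.

# Illustrative only: the flow_shift selection and scheduler the script builds
# when the model type contains "FLF2V", using the kwargs shown in the diff.
from diffusers import UniPCMultistepScheduler

model_type = "Wan-FLF2V-14B-720P"
flow_shift = 16.0 if "FLF2V" in model_type else 3.0                 # -> 16.0
needs_image_encoder = "I2V" in model_type or "FLF2V" in model_type  # -> True

scheduler = UniPCMultistepScheduler(
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
    num_train_timesteps=1000,
    flow_shift=flow_shift,
)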
|
0 commit comments