6 | 6 | from accelerate import init_empty_weights |
7 | 7 | from huggingface_hub import snapshot_download, hf_hub_download |
8 | 8 | from safetensors.torch import load_file |
9 | | -from transformers import UMT5EncoderModel, AutoTokenizer |
| 9 | +from transformers import UMT5EncoderModel, AutoTokenizer, CLIPVisionModelWithProjection, AutoProcessor |
10 | 10 |
11 | | -from diffusers import WanTransformer3DModel, FlowMatchEulerDiscreteScheduler, WanPipeline, WanImageToVideoPipeline |
| 11 | +from diffusers import WanTransformer3DModel, FlowMatchEulerDiscreteScheduler, WanPipeline, WanImageToVideoPipeline, AutoencoderKLWan |
12 | 12 |
13 | 13 |
14 | 14 | TRANSFORMER_KEYS_RENAME_DICT = { |
@@ -357,7 +357,10 @@ def convert_vae(): |
357 | 357 | # Keep other keys unchanged |
358 | 358 | new_state_dict[key] = value |
359 | 359 |
360 | | - return new_state_dict |
| 360 | + with init_empty_weights(): |
| 361 | + vae = AutoencoderKLWan() |
| 362 | + vae.load_state_dict(new_state_dict, strict=True, assign=True) |
| 363 | + return vae |
361 | 364 |
362 | 365 |
363 | 366 | def get_args(): |
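
Note on the `convert_vae()` change above: `init_empty_weights()` builds the `AutoencoderKLWan` module on the meta device so no weight memory is allocated up front, and `load_state_dict(..., assign=True)` then attaches the converted tensors directly in place of the meta parameters (a copy-based load would fail, since meta tensors have no storage). A minimal, self-contained sketch of the same pattern on a toy module:

```python
# Toy illustration of the meta-device loading pattern used in convert_vae();
# AutoencoderKLWan works the same way, just with a much larger state dict.
import torch
import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights():
    model = nn.Linear(4, 4)  # parameters are created on the "meta" device
assert model.weight.is_meta

state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
# assign=True swaps the real tensors in for the meta ones instead of copying;
# strict=True surfaces any key the renaming pass might have missed.
model.load_state_dict(state_dict, strict=True, assign=True)
assert not model.weight.is_meta
```
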
@@ -388,15 +391,24 @@ def get_args(): |
388 | 391 | scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) |
389 | 392 |
390 | 393 | if "I2V" in args.model_type: |
391 | | - pipeline_cls = WanImageToVideoPipeline |
| 394 | + image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16) |
| 395 | + image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") |
| 396 | + pipe = WanImageToVideoPipeline( |
| 397 | + transformer=transformer, |
| 398 | + text_encoder=text_encoder, |
| 399 | + tokenizer=tokenizer, |
| 400 | + vae=vae, |
| 401 | + scheduler=scheduler, |
| 402 | + image_encoder=image_encoder, |
| 403 | + image_processor=image_processor, |
| 404 | + ) |
392 | 405 | else: |
393 | | - pipeline_cls = WanPipeline |
394 | | - |
395 | | - pipe = pipeline_cls( |
396 | | - transformer=transformer, |
397 | | - text_encoder=text_encoder, |
398 | | - tokenizer=tokenizer, |
399 | | - vae=vae, |
400 | | - scheduler=scheduler, |
401 | | - ) |
| 406 | + pipe = WanPipeline( |
| 407 | + transformer=transformer, |
| 408 | + text_encoder=text_encoder, |
| 409 | + tokenizer=tokenizer, |
| 410 | + vae=vae, |
| 411 | + scheduler=scheduler, |
| 412 | + ) |
| 413 | + |
402 | 414 | pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") |
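
Once the script has run, the converted checkpoint saved to `args.output_path` should be loadable with the matching pipeline class. A minimal usage sketch for the I2V case; the output path, image URL, prompt, and generation settings below are placeholders, not values from this PR:

```python
# Sketch only: path, image URL, prompt, and settings are illustrative.
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Directory passed as --output_path to the conversion script
pipe = WanImageToVideoPipeline.from_pretrained(
    "path/to/converted-wan-i2v", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("https://example.com/first_frame.png")  # hypothetical input frame
video = pipe(
    image=image,
    prompt="A cat walks on the grass, realistic style.",
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(video, "i2v_output.mp4", fps=16)
```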