# Add wan 2.1 model #1409
Open — belkakari wants to merge 8 commits into ml-explore:main from belkakari:wan-2.1
Commits (8):

- 10dca41 initial wan2.1 implementation
- 2f57377 add image2video, distilled models and teacache support
- b303a7b image2video an text2video fixes, readme update
- fa9e5eb reduce the amount of reshapes, make everything channel-last, changed …
- ddd567c remove redundant code, fix documentation
- 89668b9 PR fixes, pt1
- 1953e57 PR review pt.2
- 9ef29a1 readme fix
Wan2.1
======

Wan2.1 text-to-video and image-to-video implementation in MLX. The model
weights are downloaded directly from the [Hugging Face
Hub](https://huggingface.co/Wan-AI).

| Model | Task | HF Repo | RAM (unquantized), 81 frames | Single DiT step on M4 Max chip, 81 frames |
|-------|------|---------|------------------------------|-------------------------------------------|
| 1.3B | T2V | [Wan-AI/Wan2.1-T2V-1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) | ~10GB | ~100 s/it |
| 14B | T2V | [Wan-AI/Wan2.1-T2V-14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) | ~36GB | ~230 s/it |
| 14B | I2V | [Wan-AI/Wan2.1-I2V-14B-480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) | ~39GB | ~250 s/it |

| T2V 1.3B | T2V 14B | I2V 14B |
|---|---|---|
| Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage. | Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage. | An astronaut riding a horse |

*(Example videos are not rendered in this view.)*

Installation
------------
Install the dependencies:

```shell
pip install -r requirements.txt
```

> [!Note]
> Saving videos requires [ffmpeg](https://ffmpeg.org/) on your PATH.

Usage
-----
### Text-to-Video

Generate a video with the default 1.3B model:

```shell
python txt2video.py 'A cat playing piano' --output out.mp4
```

Use the 14B model with quantization:

```shell
python txt2video.py 'A cat playing piano' \
    --model t2v-14B --quantize --output out_14B.mp4
```

Adjust resolution, frame count, and sampling parameters:

```shell
python txt2video.py 'Ocean waves crashing on a rocky shore at sunset' \
    --size 832x480 --frames 81 --steps 50 --guidance 5.0 --seed 42 \
    --output waves.mp4
```

For more parameters, use `python txt2video.py --help`.
### Image-to-Video

Generate a video from an input image:

```shell
python img2video.py 'Astronaut riding a horse' \
    --image ./inputs/astronaut-on-a-horse.png --quantize --output out_i2v.mp4
```

Adjust resolution and sampling parameters:

```shell
python img2video.py 'Astronaut riding a horse' \
    --image ./inputs/astronaut-on-a-horse.png --size 832x480 --frames 81 --steps 40 \
    --guidance 5.0 --shift 3.0 --seed 42 --output out_i2v.mp4
```

For more parameters, use `python img2video.py --help`.

### Quantization

Pass `--quantize` (or `-q`) to the CLI:

```shell
python txt2video.py 'A cat playing piano' --quantize --output out_quantized.mp4
```
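The selection logic the scripts use is a predicate passed to `nn.quantize`: only modules that implement `to_quantized` and whose input dimension is a multiple of 512 get quantized. A minimal standalone sketch of that predicate (the `Layer` class below is purely illustrative, not a real MLX module):

```python
import numpy as np

class Layer:
    """Illustrative stand-in for an MLX module (not a real MLX class)."""
    def __init__(self, out_dim, in_dim, quantizable=True):
        self.weight = np.zeros((out_dim, in_dim))
        if quantizable:
            self.to_quantized = lambda: None  # real MLX modules implement this

def quantization_predicate(name, m):
    # Quantize only modules that support it and whose input dimension is a
    # multiple of 512, so the weight stays compatible with group quantization.
    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0

print(quantization_predicate("attn.qkv", Layer(1536, 5120)))  # True
print(quantization_predicate("head", Layer(16, 384)))         # False
```

In the actual scripts this predicate is applied to the DiT via `nn.quantize(pipeline.flow, bits=..., class_predicate=quantization_predicate)`.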
### Disabling the cache

To trade a bit of speed for additional memory savings, use the `--no-cache` argument. It prevents MLX from using the Metal buffer cache (it sets `mx.set_cache_limit(0)` under the hood); see the [documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.set_cache_limit.html) for more info.

```shell
python txt2video.py 'A cat playing piano' --output out.mp4 --no-cache
```

For the 1.3B model at 480p with 81 frames, a `--no-cache` run uses ~10GB of RAM, versus ~14GB otherwise.
### Custom DiT Weights

Use `--checkpoint` to load custom DiT weights (e.g. [step-distilled models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)).
Pass `--sampler euler` to use Euler sampling for step-distilled models.

For the text-to-video pipeline, you can try [this 4-step distilled model](https://huggingface.co/lightx2v/Wan2.1-Distill-Models/blob/main/wan2.1_t2v_14b_lightx2v_4step.safetensors):

```shell
wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_t2v_14b_lightx2v_4step.safetensors
```

```shell
python txt2video.py 'A cat playing piano' \
    --model t2v-14B --checkpoint ./wan2.1_t2v_14b_lightx2v_4step.safetensors \
    --sampler euler --steps 4 --guidance 1.0 \
    --quantize --output out_t2v_distilled.mp4
```
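With `--sampler euler`, the scripts derive the denoising schedule as evenly spaced integer timesteps counting down from 1000. A standalone sketch of that computation:

```python
def euler_step_list(n_steps, t_max=1000):
    # Evenly spaced integer timesteps counting down from t_max,
    # e.g. 4 steps -> [1000, 750, 500, 250].
    return [t_max * i // n_steps for i in range(n_steps, 0, -1)]

print(euler_step_list(4))  # [1000, 750, 500, 250]
print(euler_step_list(2))  # [1000, 500]
```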
For the image-to-video pipeline, you can use the [4-step distilled I2V model](https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_i2v_480p_lightx2v_4step.safetensors):

```shell
wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_i2v_480p_lightx2v_4step.safetensors
```

```shell
python img2video.py 'Astronaut riding a horse' \
    --image ./inputs/astronaut-on-a-horse.png --checkpoint ./wan2.1_i2v_480p_lightx2v_4step.safetensors \
    --sampler euler --steps 4 --guidance 1.0 --shift 5.0 \
    --quantize --output out_i2v_distilled.mp4
```
### Options

- **Negative prompts**: `--n-prompt 'blurry, low quality, distorted'`
- **Disable CFG**: `--guidance 1.0` skips the unconditional pass, roughly
  halving compute per step.
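Why `--guidance 1.0` halves compute: classifier-free guidance combines a conditional and an unconditional prediction each step, and at guidance 1.0 the combination reduces to the conditional output alone, so the unconditional forward pass is unnecessary. A sketch of the standard CFG combine (an illustration, not this repo's exact code):

```python
import numpy as np

def cfg_combine(cond, uncond, guidance):
    # Standard classifier-free guidance: extrapolate from the unconditional
    # prediction toward the conditional one by the guidance factor.
    return uncond + guidance * (cond - uncond)

cond = np.array([2.0, -1.0])
uncond = np.array([0.5, 0.0])
print(cfg_combine(cond, uncond, 1.0))  # identical to cond
print(cfg_combine(cond, uncond, 5.0))  # pushed further away from uncond
```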
### TeaCache

[TeaCache](https://arxiv.org/abs/2411.19108) skips redundant transformer computation when consecutive steps
produce similar embeddings, eliminating 20-60% of forward passes. The TeaCache parameters are calibrated per resolution; consult the [LightX2V](https://github.com/ModelTC/LightX2V/tree/main/configs/caching) configs for advanced tweaking. Our defaults are in [pipeline.py](./wan/pipeline.py#L20).

```shell
python txt2video.py 'A cat playing piano' --teacache 0.05 --output out.mp4 --verbose
```

Recommended thresholds (1.3B):

| Threshold | Skip Rate | Quality |
|-----------|-----------|---------|
| `0.05` | ~34% | Almost lossless |
| `0.1` | ~58% | Slightly corrupted |
| `0.25` | ~76% | Visible quality loss |
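The skip decision can be sketched as follows. This is a simplified illustration of the TeaCache idea, not this repository's implementation (the real method also rescales the distance with a fitted polynomial, which is what the per-resolution calibration tunes): accumulate the relative change of the transformer's timestep-modulated input, and skip the forward pass while the accumulated change stays under the threshold.

```python
import numpy as np

def teacache_should_skip(prev_inp, cur_inp, acc, threshold):
    # Relative L1 distance between the modulated inputs of consecutive steps.
    rel = np.abs(cur_inp - prev_inp).mean() / (np.abs(prev_inp).mean() + 1e-8)
    acc += rel
    if acc < threshold:
        return True, acc   # similar enough: reuse the cached residual
    return False, 0.0      # run the full forward pass and reset the accumulator

# A small change accumulates below the threshold and is skipped;
# a large change forces a full forward pass.
skip, acc = teacache_should_skip(np.ones(8), np.full(8, 1.01), 0.0, 0.05)
print(skip)  # True
skip, acc = teacache_should_skip(np.ones(8), np.full(8, 2.0), acc, 0.05)
print(skip)  # False
```

A higher threshold lets more change accumulate before recomputing, which is why the skip rate (and the quality loss) grows with the threshold in the table above.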
#### Results with `--teacache` for the 1.3B model

Prompt: `Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.`

| `--teacache 0.05`, 34% steps skipped (17/50) | `--teacache 0.1`, 58% steps skipped (29/50) | `--teacache 0.25`, 76% steps skipped (38/50) |
|---|---|---|

*(Example videos are not rendered in this view.)*

References
----------

1. [Original WAN 2.1 implementation](https://github.com/Wan-Video/Wan2.1)
2. [LightX2V](https://github.com/ModelTC/LightX2V)
`img2video.py`:

```python
# Copyright © 2026 Apple Inc.

"""Generate videos from an image and text prompt using Wan2.1 I2V."""

import argparse
import logging

import mlx.core as mx
import mlx.nn as nn
from tqdm import tqdm
from wan import WanPipeline
from wan.utils import save_video


def quantization_predicate(name, m):
    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate videos from an image and text prompt using Wan2.1 I2V"
    )
    parser.add_argument("prompt")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--model", choices=["i2v-14B"], default="i2v-14B")
    parser.add_argument(
        "--size",
        type=lambda x: tuple(map(int, x.split("x"))),
        default=(832, 480),
        help="Video size as WxH (default: 832x480)",
    )
    parser.add_argument("--frames", type=int, default=81)
    parser.add_argument(
        "--steps", type=int, default=40, help="Number of denoising steps"
    )
    parser.add_argument("--guidance", type=float, default=5.0)
    parser.add_argument("--shift", type=float, default=3.0)
    parser.add_argument("--seed", type=int)
    parser.add_argument(
        "--quantize",
        "-q",
        type=int,
        nargs="?",
        const=8,
        default=0,
        choices=[0, 4, 8],
        metavar="{4,8}",
        help="Quantize DiT weights (default: 8-bit when flag used without value)",
    )
    parser.add_argument(
        "--n-prompt",
        default="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    )
    parser.add_argument(
        "--teacache",
        type=float,
        default=0.0,
        help="TeaCache threshold for step skipping (0=off, 0.26=recommended for i2v)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to custom DiT weights (.safetensors), e.g. distilled models",
    )
    parser.add_argument(
        "--sampler",
        choices=["unipc", "euler"],
        default="unipc",
        help="Sampler: unipc (default) or euler (for step-distilled models)",
    )
    parser.add_argument("--output", default="out.mp4")
    parser.add_argument("--preload-models", action="store_true")
    parser.add_argument(
        "--compile-vae", action="store_true", help="Compile VAE decoder"
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable Metal buffer cache (mx.set_cache_limit(0)) to reduce swap pressure",
    )
    parser.add_argument("--verbose", "-v", action="store_true")
    args = parser.parse_args()

    if args.sampler == "euler":
        # Evenly spaced steps: e.g. 4 steps -> [1000, 750, 500, 250]
        n = args.steps
        denoising_step_list = [1000 * i // n for i in range(n, 0, -1)]
    else:
        denoising_step_list = None

    mx.set_default_device(mx.gpu)
    if args.no_cache:
        mx.set_cache_limit(0)

    if args.verbose:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(message)s"))
        logging.getLogger("wan").setLevel(logging.INFO)
        logging.getLogger("wan").addHandler(handler)

    # Load pipeline
    pipeline = WanPipeline(args.model, checkpoint=args.checkpoint)

    # Quantize DiT
    if args.quantize:
        nn.quantize(
            pipeline.flow, bits=args.quantize, class_predicate=quantization_predicate
        )
        print(f"Quantized DiT to {args.quantize}-bit")

    if args.preload_models:
        pipeline.ensure_models_are_loaded()

    # Generate latents (generator pattern)
    latents = pipeline.generate_latents(
        args.prompt,
        image_path=args.image,
        negative_prompt=args.n_prompt,
        size=args.size,
        frame_num=args.frames,
        num_steps=args.steps,
        guidance=args.guidance,
        shift=args.shift,
        seed=args.seed,
        teacache=args.teacache,
        verbose=args.verbose,
        denoising_step_list=denoising_step_list,
    )

    # 1. Conditioning
    conditioning = next(latents)
    mx.eval(conditioning)
    peak_mem_conditioning = mx.get_peak_memory() / 1024**3
    mx.reset_peak_memory()

    # Free T5 and CLIP memory
    del pipeline.t5
    if pipeline.clip is not None:
        del pipeline.clip
    mx.clear_cache()

    # 2. Denoising loop
    for x_t in tqdm(latents, total=args.steps):
        mx.eval(x_t)

    # Free DiT memory
    del pipeline.flow
    mx.clear_cache()
    peak_mem_generation = mx.get_peak_memory() / 1024**3
    mx.reset_peak_memory()

    # 3. VAE decode
    video = pipeline.decode(x_t, compile_vae=args.compile_vae)
    mx.eval(video)
    peak_mem_decoding = mx.get_peak_memory() / 1024**3

    # Save video
    save_video(video, args.output)

    if args.verbose:
        peak_mem_overall = max(
            peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
        )
        print(f"Peak memory conditioning: {peak_mem_conditioning:.3f}GB")
        print(f"Peak memory generation: {peak_mem_generation:.3f}GB")
        print(f"Peak memory decoding: {peak_mem_decoding:.3f}GB")
        print(f"Peak memory overall: {peak_mem_overall:.3f}GB")
```
`requirements.txt`:

```
einops>=0.8.2  # for mlx compatible einops
huggingface_hub
mlx>=0.31.0  # for conv3d memory and speed fix
numpy
Pillow
tokenizers
torch  # for loading of huggingface weights
tqdm
```