4 changes: 4 additions & 0 deletions README.md
@@ -26,6 +26,10 @@ package for LLMs with MLX.
- Image classification using [ResNets on CIFAR-10](cifar).
- Convolutional variational autoencoder [(CVAE) on MNIST](cvae).

### Video Models

- Text-to-video and image-to-video generation with [Wan2.1](video/wan2.1).

### Audio Models

- Speech recognition with [OpenAI's Whisper](whisper).
154 changes: 154 additions & 0 deletions video/wan2.1/README.md
@@ -0,0 +1,154 @@
Wan2.1
======

Wan2.1 text-to-video and image-to-video implementation in MLX. The model
weights are downloaded directly from the [Hugging Face
Hub](https://huggingface.co/Wan-AI).

| Model | Task | HF Repo | RAM (unquantized, 81 frames) | Time per DiT step (M4 Max, 81 frames) |
|-------|------|---------|------------------------------|---------------------------------------|
| 1.3B | T2V | [Wan-AI/Wan2.1-T2V-1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) | ~10GB | ~100 s/it |
| 14B | T2V | [Wan-AI/Wan2.1-T2V-14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) | ~36GB | ~230 s/it |
| 14B | I2V | [Wan-AI/Wan2.1-I2V-14B-480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) | ~39GB | ~250 s/it |

| T2V 1.3B | T2V 14B | I2V 14B |
|---|---|---|
| ![WAN t2v 1.3B](static/out_t2v_1_3b.gif) | ![WAN t2v 14B distilled](static/out_t2v_cats.gif) | ![WAN i2v 14B distilled](static/out_i2v_astronaut.gif) |
| Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage. | Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage. | An astronaut riding a horse |

Installation
------------

Install the dependencies:

pip install -r requirements.txt

> [!NOTE]
> Saving videos requires [ffmpeg](https://ffmpeg.org/) on your PATH.

Usage
-----

### Text-to-Video

Generate a video with the default 1.3B model:

```shell
python txt2video.py 'A cat playing piano' --output out.mp4
```

Use the 14B model with quantization:

```shell
python txt2video.py 'A cat playing piano' \
--model t2v-14B --quantize --output out_14B.mp4
```

Adjust resolution, frame count, and sampling parameters:

```shell
python txt2video.py 'Ocean waves crashing on a rocky shore at sunset' \
--size 832x480 --frames 81 --steps 50 --guidance 5.0 --seed 42 \
--output waves.mp4
```

For more parameters, use `python txt2video.py --help`.

### Image-to-Video

Generate a video from an input image:

```shell
python img2video.py 'Astronaut riding a horse' \
--image ./inputs/astronaut-on-a-horse.png --quantize --output out_i2v.mp4
```

Adjust resolution and sampling parameters:

```shell
python img2video.py 'Astronaut riding a horse' \
--image ./inputs/astronaut-on-a-horse.png --size 832x480 --frames 81 --steps 40 \
--guidance 5.0 --shift 3.0 --seed 42 --output out_i2v.mp4
```

For more parameters, use `python img2video.py --help`.

### Quantization

Pass `--quantize` (or `-q`) to quantize the DiT weights to 8 bits, or give an explicit bit width with `--quantize 4`:

```shell
python txt2video.py 'A cat playing piano' --quantize --output out_quantized.mp4
```
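Only layers whose input dimension is a multiple of 512 are quantized (see `quantization_predicate` in the scripts), so group quantization stays valid; the rest are left in full precision. A pure-Python sketch of that selection rule, using hypothetical layer names and shapes rather than the real model's:

```python
# Hypothetical layer names and weight shapes, for illustration only.
layers = {
    "blocks.0.attn.q": (1536, 1536),   # 1536 % 512 == 0 -> quantized
    "text_embedding.0": (4096, 4096),  # 4096 % 512 == 0 -> quantized
    "patch_embedding": (1536, 48),     # 48 % 512 != 0 -> full precision
}

def should_quantize(shape):
    # Mirrors the predicate in txt2video.py/img2video.py: only layers whose
    # input dimension (shape[1]) is a multiple of 512 are quantized.
    return shape[1] % 512 == 0

selected = [name for name, shape in layers.items() if should_quantize(shape)]
print(selected)
```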

### Disabling the Cache

For additional memory savings at the cost of some speed, pass `--no-cache`. This prevents MLX from using its buffer cache (it sets `mx.set_cache_limit(0)` under the hood); see the [documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.set_cache_limit.html) for more info.
```shell
python txt2video.py 'A cat playing piano' --output out.mp4 --no-cache
```

With the 1.3B model at 480p and 81 frames, a `--no-cache` run uses ~10GB of RAM, versus ~14GB otherwise.

### Custom DiT Weights

Use `--checkpoint` to load custom DiT weights (e.g. [step-distilled models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)), and pass `--sampler euler` to use Euler sampling with them.

For the text-to-video pipeline, you can try [this 4-step distilled model](https://huggingface.co/lightx2v/Wan2.1-Distill-Models/blob/main/wan2.1_t2v_14b_lightx2v_4step.safetensors):

```shell
wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_t2v_14b_lightx2v_4step.safetensors
```

```shell
python txt2video.py 'A cat playing piano' \
--model t2v-14B --checkpoint ./wan2.1_t2v_14b_lightx2v_4step.safetensors \
--sampler euler --steps 4 --guidance 1.0 \
--quantize --output out_t2v_distilled.mp4
```
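With `--sampler euler` the scripts build an evenly spaced, descending denoising schedule from the step count. This mirrors the schedule construction in `txt2video.py`/`img2video.py` (the helper name here is illustrative):

```python
def euler_step_list(num_steps, max_timestep=1000):
    # Evenly spaced, descending timesteps, e.g. 4 steps -> [1000, 750, 500, 250]
    return [max_timestep * i // num_steps for i in range(num_steps, 0, -1)]

print(euler_step_list(4))
```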

For the image-to-video pipeline, use the [4-step distilled I2V model](https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_i2v_480p_lightx2v_4step.safetensors):

```shell
wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_i2v_480p_lightx2v_4step.safetensors
```

```shell
python img2video.py 'Astronaut riding a horse' \
--image ./inputs/astronaut-on-a-horse.png --checkpoint ./wan2.1_i2v_480p_lightx2v_4step.safetensors \
--sampler euler --steps 4 --guidance 1.0 --shift 5.0 \
--quantize --output out_i2v_distilled.mp4
```

### Options

- **Negative prompts**: `--n-prompt 'blurry, low quality, distorted'`
- **Disable CFG**: `--guidance 1.0` skips the unconditional pass, roughly
halving compute per step.
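To see why `--guidance 1.0` halves the work, here is a toy sketch of the classifier-free guidance update (scalars stand in for the model's noise predictions; the function name is illustrative):

```python
def cfg_combine(uncond, cond, guidance):
    # Classifier-free guidance extrapolates from the unconditional prediction
    # toward the conditional one. At guidance == 1.0 the result is exactly
    # `cond`, so the unconditional forward pass (half the compute per step)
    # can be skipped entirely.
    return uncond + guidance * (cond - uncond)

for g in (5.0, 1.0):
    print(g, cfg_combine(0.2, 0.8, g))
```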

### TeaCache

[TeaCache](https://arxiv.org/abs/2411.19108) skips redundant transformer computations when consecutive steps
produce similar embeddings, eliminating 20-60% of forward passes. Note that the TeaCache parameters are calibrated per resolution; consult the [LightX2V](https://github.com/ModelTC/LightX2V/tree/main/configs/caching) configs for advanced tweaking. Our defaults are in [pipeline.py](./wan/pipeline.py#20).

```shell
python txt2video.py 'A cat playing piano' --teacache 0.05 --output out.mp4 --verbose
```

Recommended thresholds (1.3B):

| Threshold | Skip Rate | Quality |
|-----------|-----------|---------|
| `0.05` | ~34% | Almost lossless |
| `0.1` | ~58% | Slightly corrupted |
| `0.25` | ~76% | Visible quality loss |
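A simplified sketch of the skip rule (the real implementation rescales distances with a fitted polynomial over modulated inputs; this toy version just accumulates relative change between steps):

```python
def teacache_skips(rel_changes, threshold):
    # Reuse the cached transformer output while the accumulated relative
    # change between consecutive steps stays below the threshold; run a
    # full forward pass (and reset the accumulator) once it is exceeded.
    skips, accumulated = [], 0.0
    for change in rel_changes:
        accumulated += change
        if accumulated < threshold:
            skips.append(True)   # cached step: reuse the previous residual
        else:
            skips.append(False)  # full DiT forward pass
            accumulated = 0.0
    return skips

# Small per-step changes let consecutive steps be skipped in a row.
print(teacache_skips([0.02, 0.02, 0.02, 0.10], threshold=0.05))
```

Raising the threshold lets the accumulator tolerate larger drift before recomputing, which is exactly the skip-rate/quality trade-off shown in the table above.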

#### Results with `--teacache` for the 1.3B model
`Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.`
|`--teacache 0.05`, 34% steps skipped (17/50) |`--teacache 0.1`, 58% steps skipped (29/50) |`--teacache 0.25`, 76% steps skipped (38/50) |
|---|---|---|
|![WAN t2v 1.3B teacache=0.05](static/out_t2v_1_3b_teacache_005.gif)|![WAN t2v 1.3B teacache=0.1](static/out_t2v_1_3b_teacache_01.gif)|![WAN t2v 1.3B teacache=0.25](static/out_t2v_1_3b_teacache_025.gif)|

References
----------

1. [Original Wan2.1 implementation](https://github.com/Wan-Video/Wan2.1)
2. [LightX2V](https://github.com/ModelTC/LightX2V)
168 changes: 168 additions & 0 deletions video/wan2.1/img2video.py
@@ -0,0 +1,168 @@
# Copyright © 2026 Apple Inc.

"""Generate videos from an image and text prompt using Wan2.1 I2V."""

import argparse
import logging

import mlx.core as mx
import mlx.nn as nn
from tqdm import tqdm
from wan import WanPipeline
from wan.utils import save_video


def quantization_predicate(name, m):
    # Only quantize layers whose input dimension is a multiple of 512, so
    # they are compatible with the quantization group size.
    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate videos from an image and text prompt using Wan2.1 I2V"
)
parser.add_argument("prompt")
parser.add_argument("--image", required=True, help="Path to input image")
parser.add_argument("--model", choices=["i2v-14B"], default="i2v-14B")
parser.add_argument(
"--size",
type=lambda x: tuple(map(int, x.split("x"))),
default=(832, 480),
help="Video size as WxH (default: 832x480)",
)
parser.add_argument("--frames", type=int, default=81)
parser.add_argument(
"--steps", type=int, default=40, help="Number of denoising steps"
)
parser.add_argument("--guidance", type=float, default=5.0)
parser.add_argument("--shift", type=float, default=3.0)
parser.add_argument("--seed", type=int)
parser.add_argument(
"--quantize",
"-q",
type=int,
nargs="?",
const=8,
default=0,
choices=[0, 4, 8],
metavar="{4,8}",
help="Quantize DiT weights (default: 8-bit when flag used without value)",
)
parser.add_argument(
"--n-prompt",
        # Default is the standard Wan2.1 negative prompt (in Chinese): camera
        # shake, garish tones, overexposure, static frames, blurry details,
        # subtitles, worst/low quality, JPEG artifacts, deformed or disfigured
        # bodies, extra or fused fingers, poorly drawn hands and faces,
        # cluttered or crowded backgrounds, walking backwards, etc.
        default="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    )
parser.add_argument(
"--teacache",
type=float,
default=0.0,
help="TeaCache threshold for step skipping (0=off, 0.26=recommended for i2v)",
)
parser.add_argument(
"--checkpoint",
type=str,
default=None,
help="Path to custom DiT weights (.safetensors), e.g. distilled models",
)
parser.add_argument(
"--sampler",
choices=["unipc", "euler"],
default="unipc",
help="Sampler: unipc (default) or euler (for step-distilled models)",
)
parser.add_argument("--output", default="out.mp4")
parser.add_argument("--preload-models", action="store_true")
parser.add_argument(
"--compile-vae", action="store_true", help="Compile VAE decoder"
)
parser.add_argument(
"--no-cache",
action="store_true",
help="Disable Metal buffer cache (mx.set_cache_limit(0)) to reduce swap pressure",
)
parser.add_argument("--verbose", "-v", action="store_true")
args = parser.parse_args()

if args.sampler == "euler":
# Evenly spaced steps: e.g. 4 steps -> [1000, 750, 500, 250]
n = args.steps
denoising_step_list = [1000 * i // n for i in range(n, 0, -1)]
else:
denoising_step_list = None

mx.set_default_device(mx.gpu)
if args.no_cache:
mx.set_cache_limit(0)

if args.verbose:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(message)s"))
logging.getLogger("wan").setLevel(logging.INFO)
logging.getLogger("wan").addHandler(handler)

# Load pipeline
pipeline = WanPipeline(args.model, checkpoint=args.checkpoint)

# Quantize DiT
if args.quantize:
nn.quantize(
pipeline.flow, bits=args.quantize, class_predicate=quantization_predicate
)
print(f"Quantized DiT to {args.quantize}-bit")

if args.preload_models:
pipeline.ensure_models_are_loaded()

# Generate latents (generator pattern)
latents = pipeline.generate_latents(
args.prompt,
image_path=args.image,
negative_prompt=args.n_prompt,
size=args.size,
frame_num=args.frames,
num_steps=args.steps,
guidance=args.guidance,
shift=args.shift,
seed=args.seed,
teacache=args.teacache,
verbose=args.verbose,
denoising_step_list=denoising_step_list,
)

# 1. Conditioning
conditioning = next(latents)
mx.eval(conditioning)
peak_mem_conditioning = mx.get_peak_memory() / 1024**3
mx.reset_peak_memory()

# Free T5 and CLIP memory
del pipeline.t5
if pipeline.clip is not None:
del pipeline.clip
mx.clear_cache()

# 2. Denoising loop
for x_t in tqdm(latents, total=args.steps):
mx.eval(x_t)

# Free DiT memory
del pipeline.flow
mx.clear_cache()
peak_mem_generation = mx.get_peak_memory() / 1024**3
mx.reset_peak_memory()

# 3. VAE decode
video = pipeline.decode(x_t, compile_vae=args.compile_vae)
mx.eval(video)
peak_mem_decoding = mx.get_peak_memory() / 1024**3

# Save video
save_video(video, args.output)

if args.verbose:
peak_mem_overall = max(
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
)
print(f"Peak memory conditioning: {peak_mem_conditioning:.3f}GB")
print(f"Peak memory generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory decoding: {peak_mem_decoding:.3f}GB")
print(f"Peak memory overall: {peak_mem_overall:.3f}GB")
Binary file added video/wan2.1/inputs/astronaut-on-a-horse.png
8 changes: 8 additions & 0 deletions video/wan2.1/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
einops>=0.8.2  # MLX-compatible einops
huggingface_hub
mlx>=0.31.0  # conv3d memory and speed fix
numpy
Pillow
tokenizers
torch  # for loading Hugging Face weights
tqdm
Binary file added video/wan2.1/static/out_i2v_astronaut.gif
Binary file added video/wan2.1/static/out_t2v_1_3b.gif
Binary file added video/wan2.1/static/out_t2v_1_3b_teacache_005.gif
Binary file added video/wan2.1/static/out_t2v_1_3b_teacache_01.gif
Binary file added video/wan2.1/static/out_t2v_1_3b_teacache_025.gif
Binary file added video/wan2.1/static/out_t2v_cats.gif