Add wan 2.1 model #1409

Open

belkakari wants to merge 8 commits into ml-explore:main from belkakari:wan-2.1

Conversation

@belkakari

@belkakari belkakari commented Mar 11, 2026

This PR adds WAN 2.1 text-to-video and image-to-video model support, with optimizations such as TeaCache and support for step-distilled models. Based on the original WAN 2.1 implementation and LightX2V.

Basic commands:

1.3B text-to-video

python txt2video.py 'A cat playing piano' --output out.mp4

14B text-to-video

python txt2video.py 'A cat playing piano' --model t2v-14B --quantize --output out_14B.mp4

14B image-to-video

python img2video.py 'Astronaut riding a horse' \
   --image ./inputs/astronaut-on-a-horse.png --quantize --output out_i2v.mp4

Step distilled models:
T2V

wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_t2v_14b_lightx2v_4step.safetensors
python txt2video.py 'A cat playing piano' \
    --model t2v-14B --checkpoint ./wan2.1_t2v_14b_lightx2v_4step.safetensors \
    --sampler euler --steps 4 --guidance 1.0 \
    --quantize --output out_t2v_distilled.mp4

I2V

wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_i2v_480p_lightx2v_4step.safetensors
python img2video.py 'Astronaut riding a horse' \
    --image ./inputs/astronaut-on-a-horse.png --checkpoint ./wan2.1_i2v_480p_lightx2v_4step.safetensors \
    --sampler euler --steps 4 --guidance 1.0 --shift 5.0 \
    --quantize --output out_i2v_distilled.mp4

Member

@angeloskath left a comment


Great work @belkakari !

I left quite a few comments, a lot of them stylistic and some of them performance related.

One more general comment: the weights are in float32. That seems unnecessary and will impact performance as well as memory. It also hides possible upcasting if someone wants the computation to happen in bf16 or another dtype, which will be much faster on M5s for instance.

def _project_time_fn(e, w, b):
    x = nn.silu(e)
    x = mx.matmul(x, w.T) + b
    return x
Member


Almost certainly no need to compile the above. Using nn.Linear layers would be more understandable and the same speed. It would also use mx.addmm, which fuses the x @ w.T + b into one op. For the sinusoidal embedding you can use nn.SinusoidalPositionalEncoding. You can wrap them in lists or nn.Sequential like their PyTorch counterparts. It would also help with quantization, as they would be automatically quantized, which won't quite happen now (unless you implement a custom to_quantized function).

self.patch_embedding_weight = mx.random.normal((dim, *patch_size, in_dim)) * (
    1.0 / (in_dim * math.prod(patch_size)) ** 0.5
)
self.patch_embedding_bias = mx.zeros((dim,))
Member


Same goes here: why not nn.Conv3d?

    eps,
    cross_attn_type=model_type,
)
setattr(self, f"block_{i}", block)
Member


Why not a list? This just makes your life hard when you want to iterate over them, which you always will want to.

value = mx.transpose(value, (0, 2, 3, 4, 1))

# blocks.N -> block_N
new_key = re.sub(r"blocks\.(\d+)\.", r"block_\1.", new_key)
Member


Most of these are not needed when the blocks are put in a list and the ffn in an nn.Sequential (or a list), etc.


# Merge separate Q/K/V into QKV for self-attention,
# and K/V into KV for cross-attention
remapped = WanModel._merge_qkv_weights(remapped)
Member


That is good, especially for distributed inference later, but generally it doesn't provide much of a speedup. The q, k, v projections will happen in parallel on the GPU anyway.

Just to be clear, a good optimization but probably not the lowest hanging fruit.

Member


The lowest hanging fruit is probably updating the modulation parameter in the layernorm to contain the 1 +, so that the layernorm can just use self.modulation + e directly; see comments in layers.py.

Author


Got it, added the modulation as you suggested.

)
parser.add_argument(
    "--n-prompt",
    default="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
Member


Same as in txt2img.py.

# Register all layers via setattr
for stage_idx, stage in enumerate(self.upsamples):
    for layer_idx, (layer_type, layer) in enumerate(stage):
        setattr(self, f"upsample_s{stage_idx}_l{layer_idx}_{layer_type}", layer)
Member


Same comment as for model.py: no need for these to be set as attributes, they can be nested lists.

Comment on lines +423 to +431
if compile and i == 1 and self._compiled_decode is None:
    self._compiled_decode = mx.compile(self.decoder._forward_functional)

if self._compiled_decode is not None:
    out_frame, feat_cache = self._compiled_decode(frame, feat_cache)
else:
    out_frame, feat_cache = self.decoder._forward_functional(
        frame, feat_cache
    )
Member


Why not always compile?

Comment on lines +420 to +421
for i in range(num_frames):
    frame = x[:, i : i + 1, :, :, :]
Member


Why not batch it? For memory saving?

Author


Yes, if the video is long enough it won't fit into memory. We could make the batch size a configurable additional parameter with default=1, wdyt?

Comment on lines +318 to +324
scale = 1.0 / dim**0.5
self.to_qkv_weight = mx.random.uniform(
    low=-scale, high=scale, shape=(dim * 3, 1, 1, dim)
)
self.to_qkv_bias = mx.zeros((dim * 3,))
self.proj_weight = mx.zeros((dim, 1, 1, dim))
self.proj_bias = mx.zeros((dim,))
Member


Same as before, should just be linear layers.

@belkakari belkakari force-pushed the wan-2.1 branch 3 times, most recently from 2a3ffb5 to e4cd847 Compare March 24, 2026 13:30
@belkakari belkakari requested a review from angeloskath March 24, 2026 13:32