Conversation
angeloskath
left a comment
Great work @belkakari !
I left quite a few comments, a lot of them stylistic and some of them performance related.
One more general comment: the weights are in float32. That seems unnecessary and will impact both performance and memory. It also hides possible upcasting if someone wants the computation to happen in bf16 or another dtype, which will be much faster on M5s, for instance.
video/wan2.1/wan/model.py
```python
def _project_time_fn(e, w, b):
    x = nn.silu(e)
    x = mx.matmul(x, w.T) + b
    return x
```
Almost certainly no need to compile the above. Using nn.Linear layers would be more understandable and just as fast. It would also use mx.addmm, which fuses the x @ w.T + b into one op. For the sinusoidal embedding you can use nn.SinusoidalPositionalEncoding. You can wrap them in lists or nn.Sequential like their PyTorch counterparts. It would also help with quantization, as they would be automatically quantized, which won't quite happen now (unless you implement a custom to_quantized function).
video/wan2.1/wan/model.py
```python
self.patch_embedding_weight = mx.random.normal((dim, *patch_size, in_dim)) * (
    1.0 / (in_dim * math.prod(patch_size)) ** 0.5
)
self.patch_embedding_bias = mx.zeros((dim,))
```
Same goes for here: why not nn.Conv3d?
video/wan2.1/wan/model.py
```python
    eps,
    cross_attn_type=model_type,
)
setattr(self, f"block_{i}", block)
```
Why not a list? This just makes your life hard when you want to iterate over them, which you always will.
video/wan2.1/wan/model.py
```python
value = mx.transpose(value, (0, 2, 3, 4, 1))
```

```python
# blocks.N -> block_N
new_key = re.sub(r"blocks\.(\d+)\.", r"block_\1.", new_key)
```
Most of these are not needed once the blocks are put in a list, the ffn in a sequential (or a list), etc.
```python
# Merge separate Q/K/V into QKV for self-attention,
# and K/V into KV for cross-attention
remapped = WanModel._merge_qkv_weights(remapped)
```
That is good, especially for distributed inference later, but generally it doesn't provide much of a speedup: the q, k, v projections will happen in parallel on the GPU anyway.
Just to be clear, it is a good optimization, but probably not the lowest-hanging fruit.
The lowest-hanging fruit is probably updating the modulation parameter in the layernorm to contain the 1 +, so that the layernorm can use self.modulation + e directly; see comments in layers.py.
Got it, added the modulation as you've suggested.
video/wan2.1/img2video.py
```python
)
parser.add_argument(
    "--n-prompt",
    # Default Chinese negative prompt; roughly: camera shake, garish colors,
    # overexposure, static, blurry details, subtitles, style, artwork, painting,
    # stillness, overall gray, worst/low quality, JPEG artifacts, ugly,
    # incomplete, extra fingers, poorly drawn hands/face, deformed, disfigured,
    # malformed limbs, fused fingers, frozen frame, cluttered background,
    # three legs, crowded background, walking backwards.
    default="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
```
video/wan2.1/wan/vae.py
```python
# Register all layers via setattr
for stage_idx, stage in enumerate(self.upsamples):
    for layer_idx, (layer_type, layer) in enumerate(stage):
        setattr(self, f"upsample_s{stage_idx}_l{layer_idx}_{layer_type}", layer)
```
Same comment as for model.py: no need for these to be set as attributes, they can be nested lists.
video/wan2.1/wan/vae.py
```python
if compile and i == 1 and self._compiled_decode is None:
    self._compiled_decode = mx.compile(self.decoder._forward_functional)

if self._compiled_decode is not None:
    out_frame, feat_cache = self._compiled_decode(frame, feat_cache)
else:
    out_frame, feat_cache = self.decoder._forward_functional(
        frame, feat_cache
    )
```

```python
for i in range(num_frames):
    frame = x[:, i : i + 1, :, :, :]
```
Why not batch it? For memory saving?
Yes, if the video is long enough it won't fit into memory. We can make the batch size a configurable parameter with default=1, wdyt?
video/wan2.1/wan/vae_layers.py
```python
scale = 1.0 / dim**0.5
self.to_qkv_weight = mx.random.uniform(
    low=-scale, high=scale, shape=(dim * 3, 1, 1, dim)
)
self.to_qkv_bias = mx.zeros((dim * 3,))
self.proj_weight = mx.zeros((dim, 1, 1, dim))
self.proj_bias = mx.zeros((dim,))
```
Same as before, should just be linear layers.
This PR adds WAN 2.1 text-to-video and image-to-video model support, with optimizations such as TeaCache and support for step-distilled models. It is based on the original WAN 2.1 implementation and LightX2V.
Basic commands:
1.3B text-to-video
14B text-to-video
14B image-to-video
Step distilled models:
T2V
I2V