Skip to content

Commit 694361f

Browse files
authored
Support Wan FLF2V (#388)
* update progress on Wan FLF2V * update * update docs
1 parent f2dfb11 commit 694361f

File tree

4 files changed

+81
-24
lines changed

4 files changed

+81
-24
lines changed

docs/models/wan.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ chmod +x ./examples/training/sft/wan/crush_smol_lora/train.sh
1818

1919
On Windows, you will have to modify the script into a compatible format before running it. [TODO(aryan): improve instructions for Windows]
2020

21+
## Supported checkpoints
22+
23+
Wan has multiple checkpoints, which can be found [here](https://huggingface.co/Wan-AI). The following checkpoints were tested with `finetrainers` and are known to work:
24+
25+
- [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)
26+
- [Wan-AI/Wan2.1-T2V-14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers)
27+
- [Wan-AI/Wan2.1-I2V-14B-480P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)
28+
- [Wan-AI/Wan2.1-I2V-14B-720P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers)
29+
- [Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers)
30+
2131
## Inference
2232

2333
Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,37 @@
1+
from typing import Optional
2+
13
import torch
24

35

4-
def normalize(x: torch.Tensor, min: float = -1.0, max: float = 1.0) -> torch.Tensor:
5-
"""Normalize a tensor to the range [min_val, max_val]."""
6-
x_min = x.min()
7-
x_max = x.max()
8-
if torch.isclose(x_min, x_max).any():
9-
x = torch.full_like(x, min)
6+
def normalize(x: torch.Tensor, min: float = -1.0, max: float = 1.0, dim: Optional[int] = None) -> torch.Tensor:
7+
"""
8+
Normalize a tensor to the range [min_val, max_val].
9+
10+
Args:
11+
x (`torch.Tensor`):
12+
The input tensor to normalize.
13+
min (`float`, defaults to `-1.0`):
14+
The minimum value of the normalized range.
15+
max (`float`, defaults to `1.0`):
16+
The maximum value of the normalized range.
17+
dim (`int`, *optional*):
18+
The dimension along which to normalize. If `None`, the entire tensor is normalized.
19+
20+
Returns:
21+
The normalized tensor of the same shape as `x`.
22+
"""
23+
if dim is None:
24+
x_min = x.min()
25+
x_max = x.max()
26+
if torch.isclose(x_min, x_max).any():
27+
x = torch.full_like(x, min)
28+
else:
29+
x = min + (max - min) * (x - x_min) / (x_max - x_min)
1030
else:
11-
x = min + (max - min) * (x - x_min) / (x_max - x_min)
31+
x_min = x.amin(dim=dim, keepdim=True)
32+
x_max = x.amax(dim=dim, keepdim=True)
33+
if torch.isclose(x_min, x_max).any():
34+
x = torch.full_like(x, min)
35+
else:
36+
x = min + (max - min) * (x - x_min) / (x_max - x_min)
1237
return x

finetrainers/models/wan/base_specification.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,10 @@ class WanImageConditioningLatentEncodeProcessor(ProcessorMixin):
9494
- mask: The conditioning frame mask for the input image/video.
9595
"""
9696

97-
def __init__(self, output_names: List[str]):
97+
def __init__(self, output_names: List[str], *, use_last_frame: bool = False):
9898
super().__init__()
9999
self.output_names = output_names
100+
self.use_last_frame = use_last_frame
100101
assert len(self.output_names) == 4
101102

102103
def forward(
@@ -117,8 +118,12 @@ def forward(
117118
video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
118119

119120
num_frames = video.size(2)
120-
first_frame, remaining_frames = video[:, :, :1], video[:, :, 1:]
121-
video = torch.cat([first_frame, torch.zeros_like(remaining_frames)], dim=2)
121+
if not self.use_last_frame:
122+
first_frame, remaining_frames = video[:, :, :1], video[:, :, 1:]
123+
video = torch.cat([first_frame, torch.zeros_like(remaining_frames)], dim=2)
124+
else:
125+
first_frame, remaining_frames, last_frame = video[:, :, :1], video[:, :, 1:-1], video[:, :, -1:]
126+
video = torch.cat([first_frame, torch.zeros_like(remaining_frames), last_frame], dim=2)
122127

123128
# Image conditioning uses argmax sampling, so we use "mode" here
124129
if compute_posterior:
@@ -139,7 +144,10 @@ def forward(
139144

140145
temporal_downsample = 2 ** sum(vae.temperal_downsample) if getattr(self, "vae", None) else 4
141146
mask = latents.new_ones(latents.shape[0], 1, num_frames, latents.shape[3], latents.shape[4])
142-
mask[:, :, 1:] = 0
147+
if not self.use_last_frame:
148+
mask[:, :, 1:] = 0
149+
else:
150+
mask[:, :, 1:-1] = 0
143151
first_frame_mask = mask[:, :, :1]
144152
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=temporal_downsample)
145153
mask = torch.cat([first_frame_mask, mask[:, :, 1:]], dim=2)
@@ -164,9 +172,10 @@ class WanImageEncodeProcessor(ProcessorMixin):
164172
- image_embeds: The CLIP vision model image embeddings of the input image.
165173
"""
166174

167-
def __init__(self, output_names: List[str]):
175+
def __init__(self, output_names: List[str], *, use_last_frame: bool = False):
168176
super().__init__()
169177
self.output_names = output_names
178+
self.use_last_frame = use_last_frame
170179
assert len(self.output_names) == 1
171180

172181
def forward(
@@ -178,15 +187,19 @@ def forward(
178187
) -> Dict[str, torch.Tensor]:
179188
device = image_encoder.device
180189
dtype = image_encoder.dtype
181-
182-
if video is not None:
183-
image = video[:, 0] # [B, F, C, H, W] -> [B, C, H, W] (take first frame)
184-
185-
assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor"
190+
last_image = None
186191

187192
# We know the image here is in the range [-1, 1] (probably a little overshot if using bilinear interpolation), but
188193
# the processor expects it to be in the range [0, 1].
189-
image = FF.normalize(image, min=0.0, max=1.0)
194+
image = image if video is None else video[:, 0] # [B, F, C, H, W] -> [B, C, H, W] (take first frame)
195+
image = FF.normalize(image, min=0.0, max=1.0, dim=1)
196+
assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor"
197+
198+
if self.use_last_frame:
199+
last_image = image if video is None else video[:, -1]
200+
last_image = FF.normalize(last_image, min=0.0, max=1.0, dim=1)
201+
image = torch.stack([image, last_image], dim=0)
202+
190203
image = image_processor(images=image.float(), do_rescale=False, do_convert_rgb=False, return_tensors="pt")
191204
image = image.to(device=device, dtype=dtype)
192205
image_embeds = image_encoder(**image, output_hidden_states=True)
@@ -224,18 +237,23 @@ def __init__(
224237
cache_dir=cache_dir,
225238
)
226239

240+
use_last_frame = self.transformer_config.pos_embed_seq_len is not None
241+
227242
if condition_model_processors is None:
228-
condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
243+
condition_model_processors = [T5Processor(["encoder_hidden_states", "__drop__"])]
229244
if latent_model_processors is None:
230245
latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])]
231246

232247
if self.transformer_config.image_dim is not None:
233248
latent_model_processors.append(
234249
WanImageConditioningLatentEncodeProcessor(
235-
["latent_condition", "__drop__", "__drop__", "latent_condition_mask"]
250+
["latent_condition", "__drop__", "__drop__", "latent_condition_mask"],
251+
use_last_frame=use_last_frame,
236252
)
237253
)
238-
latent_model_processors.append(WanImageEncodeProcessor(["encoder_hidden_states_image"]))
254+
latent_model_processors.append(
255+
WanImageEncodeProcessor(["encoder_hidden_states_image"], use_last_frame=use_last_frame)
256+
)
239257

240258
self.condition_model_processors = condition_model_processors
241259
self.latent_model_processors = latent_model_processors
@@ -380,7 +398,6 @@ def prepare_conditions(
380398
input_keys = set(conditions.keys())
381399
conditions = super().prepare_conditions(**conditions)
382400
conditions = {k: v for k, v in conditions.items() if k not in input_keys}
383-
conditions.pop("prompt_attention_mask", None)
384401
return conditions
385402

386403
@torch.no_grad()
@@ -480,6 +497,7 @@ def validation(
480497
pipeline: Union[WanPipeline, WanImageToVideoPipeline],
481498
prompt: str,
482499
image: Optional[PIL.Image.Image] = None,
500+
last_image: Optional[PIL.Image.Image] = None,
483501
video: Optional[List[PIL.Image.Image]] = None,
484502
height: Optional[int] = None,
485503
width: Optional[int] = None,
@@ -501,9 +519,11 @@ def validation(
501519
if self.transformer_config.image_dim is not None:
502520
if image is None and video is None:
503521
raise ValueError("Either image or video must be provided for Wan I2V validation.")
504-
if image is None:
505-
image = video[0]
522+
image = image if image is not None else video[0]
506523
generation_kwargs["image"] = image
524+
if self.transformer_config.pos_embed_seq_len is not None:
525+
last_image = last_image if last_image is not None else image if video is None else video[-1]
526+
generation_kwargs["last_image"] = last_image
507527
generation_kwargs = get_non_null_items(generation_kwargs)
508528
video = pipeline(**generation_kwargs).frames[0]
509529
return [VideoArtifact(value=video)]

finetrainers/trainer/control_trainer/data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]:
115115
if "control_output" in shallow_copy_data:
116116
# Normalize to [-1, 1] range
117117
control_output = shallow_copy_data.pop("control_output")
118+
# TODO(aryan): need to specify a dim for normalize here across channels
118119
control_output = FF.normalize(control_output, min=-1.0, max=1.0)
119120
key = "control_image" if is_image_control else "control_video"
120121
shallow_copy_data[key] = control_output
@@ -182,6 +183,7 @@ def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]:
182183
# Normalize to [-1, 1] range
183184
control_output = shallow_copy_data.pop("control_output")
184185
if torch.is_tensor(control_output):
186+
# TODO(aryan): need to specify a dim for normalize here across channels
185187
control_output = FF.normalize(control_output, min=-1.0, max=1.0)
186188
ndim = control_output.ndim
187189
assert 3 <= ndim <= 5, "Control output should be at least ndim=3 and less than or equal to ndim=5"

0 commit comments

Comments
 (0)