Commit cb8e61e

[wan2.2] follow-up (#12024)
* up

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 8e53cd9 commit cb8e61e

File tree

7 files changed: +897 −74 lines changed

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -324,7 +324,7 @@ def forward(
     ):
         timestep = self.timesteps_proj(timestep)
         if timestep_seq_len is not None:
-            timestep = timestep.unflatten(0, (1, timestep_seq_len))
+            timestep = timestep.unflatten(0, (-1, timestep_seq_len))
 
         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
```
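The one-line fix above is easiest to see with a toy tensor: `unflatten(0, (1, timestep_seq_len))` hard-codes a batch of one and fails for batched timesteps, while `-1` lets PyTorch infer the batch dimension. A minimal sketch (shapes are illustrative, not taken from the model):

```python
import torch

# Projected timesteps arrive flattened as (batch_size * timestep_seq_len, dim).
batch_size, timestep_seq_len, dim = 2, 6, 256
timestep = torch.randn(batch_size * timestep_seq_len, dim)

# Old: unflatten(0, (1, timestep_seq_len)) needs 1 * 6 == 12, so it raises for batch_size == 2.
# New: -1 infers the leading dimension from the remaining size.
out = timestep.unflatten(0, (-1, timestep_seq_len))
print(out.shape)  # torch.Size([2, 6, 256])
```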

src/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 8 additions & 4 deletions

```diff
@@ -125,15 +125,15 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
 
     model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
-    _optional_components = ["transformer_2"]
+    _optional_components = ["transformer", "transformer_2"]
 
     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer: Optional[WanTransformer3DModel] = None,
         transformer_2: Optional[WanTransformer3DModel] = None,
         boundary_ratio: Optional[float] = None,
         expand_timesteps: bool = False,  # Wan2.2 ti2v
@@ -526,7 +526,7 @@ def __call__(
             device=device,
         )
 
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
         prompt_embeds = prompt_embeds.to(transformer_dtype)
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
@@ -536,7 +536,11 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels
+        num_channels_latents = (
+            self.transformer.config.in_channels
+            if self.transformer is not None
+            else self.transformer_2.config.in_channels
+        )
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
```
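With `transformer` now listed in `_optional_components` and defaulted to `None`, a `WanPipeline` can be assembled with only `transformer_2`, e.g. to run just the low-noise expert of a two-expert Wan 2.2 checkpoint. A rough usage sketch (the checkpoint id is a placeholder, and this assumes the usual `from_pretrained` component-override behavior):

```python
import torch
from diffusers import WanPipeline

# Placeholder checkpoint id; any two-expert Wan 2.2 T2V checkpoint would do.
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
    transformer=None,  # skip loading the high-noise expert; now allowed
    torch_dtype=torch.bfloat16,
)

# The dtype and in_channels fallbacks in the diff follow this pattern:
dtype = pipe.transformer.dtype if pipe.transformer is not None else pipe.transformer_2.dtype
```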

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 6 additions & 4 deletions

```diff
@@ -162,17 +162,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
 
     model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
-    _optional_components = ["transformer_2", "image_encoder", "image_processor"]
+    _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]
 
     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
         image_processor: CLIPImageProcessor = None,
         image_encoder: CLIPVisionModel = None,
+        transformer: WanTransformer3DModel = None,
         transformer_2: WanTransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
         expand_timesteps: bool = False,
@@ -669,12 +669,13 @@ def __call__(
         )
 
         # Encode image embedding
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
         prompt_embeds = prompt_embeds.to(transformer_dtype)
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
-        if self.config.boundary_ratio is None and not self.config.expand_timesteps:
+        # only the Wan 2.1 i2v transformer accepts image_embeds
+        if self.transformer is not None and self.transformer.config.image_dim is not None:
             if image_embeds is None:
                 if last_image is None:
                     image_embeds = self.encode_image(image, device)
@@ -709,6 +710,7 @@ def __call__(
                 last_image,
            )
            if self.config.expand_timesteps:
+               # wan 2.2 5b i2v uses first_frame_mask to mask timesteps
                latents, condition, first_frame_mask = latents_outputs
            else:
                latents, condition = latents_outputs
```
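The rewritten condition keys off the transformer's own config instead of pipeline-level flags: only the Wan 2.1-style i2v transformer declares an `image_dim`, so CLIP image embeddings are computed only when such a transformer is actually loaded. Restated as a standalone predicate (the `pipe` argument is a hypothetical pipeline object):

```python
def needs_image_embeds(pipe) -> bool:
    # Wan 2.2 transformers have image_dim=None and take no image_embeds;
    # a pipeline loaded without `transformer` also skips image encoding.
    return pipe.transformer is not None and pipe.transformer.config.image_dim is not None
```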

tests/pipelines/wan/test_wan.py

Lines changed: 42 additions & 17 deletions

```diff
@@ -13,8 +13,10 @@
 # limitations under the License.
 
 import gc
+import tempfile
 import unittest
 
+import numpy as np
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
 
@@ -85,29 +87,13 @@ def get_dummy_components(self):
             rope_max_seq_len=32,
         )
 
-        torch.manual_seed(0)
-        transformer_2 = WanTransformer3DModel(
-            patch_size=(1, 2, 2),
-            num_attention_heads=2,
-            attention_head_dim=12,
-            in_channels=16,
-            out_channels=16,
-            text_dim=32,
-            freq_dim=256,
-            ffn_dim=32,
-            num_layers=2,
-            cross_attn_norm=True,
-            qk_norm="rms_norm_across_heads",
-            rope_max_seq_len=32,
-        )
-
         components = {
             "transformer": transformer,
             "vae": vae,
             "scheduler": scheduler,
             "text_encoder": text_encoder,
             "tokenizer": tokenizer,
-            "transformer_2": transformer_2,
+            "transformer_2": None,
         }
         return components
 
@@ -155,6 +141,45 @@ def test_inference(self):
     def test_attention_slicing_forward_pass(self):
         pass
 
+    # _optional_components includes transformer and transformer_2, but only transformer_2 is optional for this Wan 2.1 t2v pipeline
+    def test_save_load_optional_components(self, expected_max_difference=1e-4):
+        optional_component = "transformer_2"
+
+        components = self.get_dummy_components()
+        components[optional_component] = None
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir, safe_serialization=False)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            for component in pipe_loaded.components.values():
+                if hasattr(component, "set_default_attn_processor"):
+                    component.set_default_attn_processor()
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        self.assertTrue(
+            getattr(pipe_loaded, optional_component) is None,
+            f"`{optional_component}` did not stay set to None after loading.",
+        )
+
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+        self.assertLess(max_diff, expected_max_difference)
+
 
 @slow
 @require_torch_accelerator
```
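To exercise just the new round-trip check locally, an invocation along the lines of `pytest tests/pipelines/wan/test_wan.py -k test_save_load_optional_components` should work; the flags are a suggestion, not part of the commit.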
