Commit ecfbc8f
[Pipelines] Enable Wan VACE to run with a single transformer (#12428)
* update * update * update * update * update
1 parent df0e2a4 commit ecfbc8f
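
This commit makes both transformer slots optional, so the pipeline can be built and run with a single transformer. A rough usage sketch, assuming a hypothetical checkpoint id and one simple conditioning setup (blank gray frames with an all-white mask, i.e. generate the whole clip from the prompt):

    import torch
    import PIL.Image
    from diffusers import WanVACEPipeline

    # Hypothetical repo id, used for illustration only.
    pipe = WanVACEPipeline.from_pretrained(
        "org/wan-vace-checkpoint",
        transformer_2=None,          # skip the optional low-noise transformer; allowed after this change
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")

    # Blank conditioning frames plus an all-white mask: everything is generated from the prompt.
    num_frames, height, width = 81, 480, 832
    video = [PIL.Image.new("RGB", (width, height), (128, 128, 128))] * num_frames
    mask = [PIL.Image.new("L", (width, height), 255)] * num_frames

    output = pipe(
        prompt="a cat surfing a wave at sunset",
        video=video,
        mask=mask,
        height=height,
        width=width,
        num_frames=num_frames,
        num_inference_steps=30,
    ).frames[0]
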

File tree

2 files changed: 137 additions & 25 deletions


src/diffusers/pipelines/wan/pipeline_wan_vace.py

Lines changed: 50 additions & 23 deletions
@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-        transformer ([`WanVACETransformer3DModel`]):
-            Conditional Transformer to denoise the input latents.
-        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
-            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
-            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
-            `transformer` is used.
-        scheduler ([`UniPCMultistepScheduler`]):
-            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        scheduler ([`UniPCMultistepScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        transformer ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
+        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
         boundary_ratio (`float`, *optional*, defaults to `None`):
             Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
             The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
             `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
-            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
+            boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
     """

-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
-    _optional_components = ["transformer_2"]
+    _optional_components = ["transformer", "transformer_2"]

     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer: WanVACETransformer3DModel = None,
         transformer_2: WanVACETransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
     ):
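
A quick numeric illustration of the `boundary_ratio` semantics restated in the docstring above (values are examples; nothing in this commit changes them):

    # Illustrative only: how the boundary partitions timesteps between the two experts.
    num_train_timesteps = 1000                     # typical scheduler default for Wan pipelines
    boundary_ratio = 0.875                         # example value; real checkpoints ship their own
    boundary_timestep = boundary_ratio * num_train_timesteps  # 875.0

    def uses_high_noise_expert(t: float) -> bool:
        # `transformer` handles t >= boundary_timestep, `transformer_2` handles t < boundary_timestep.
        return t >= boundary_timestep

    print(uses_high_noise_expert(981.0))   # True  -> transformer (high-noise stage)
    print(uses_high_noise_expert(250.0))   # False -> transformer_2 (low-noise stage)
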
@@ -336,7 +338,15 @@ def check_inputs(
         reference_images=None,
         guidance_scale_2=None,
     ):
-        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        if self.transformer is not None:
+            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        elif self.transformer_2 is not None:
+            base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
+        else:
+            raise ValueError(
+                "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
+            )
+
         if height % base != 0 or width % base != 0:
             raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")

@@ -414,7 +424,11 @@ def preprocess_conditions(
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            base = self.vae_scale_factor_spatial * (
+                self.transformer.config.patch_size[1]
+                if self.transformer is not None
+                else self.transformer_2.config.patch_size[1]
+            )
             video_height, video_width = self.video_processor.get_default_height_width(video[0])

             if video_height * video_width > height * width:
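
The same fallback (use `self.transformer` when present, otherwise `self.transformer_2`) recurs in this hunk and the ones below. Purely as an illustration of the pattern, since the commit itself keeps the inline conditionals, it could be expressed once as:

    def _any_transformer(pipe):
        # Illustrative helper: prefer the high-noise transformer, otherwise fall back to the low-noise one.
        transformer = pipe.transformer if pipe.transformer is not None else pipe.transformer_2
        if transformer is None:
            raise ValueError(
                "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
            )
        return transformer

    # e.g. base = pipe.vae_scale_factor_spatial * _any_transformer(pipe).config.patch_size[1]
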
@@ -589,7 +603,11 @@ def prepare_masks(
                 "Generating with more than one video is not yet supported. This may be supported in the future."
             )

-        transformer_patch_size = self.transformer.config.patch_size[1]
+        transformer_patch_size = (
+            self.transformer.config.patch_size[1]
+            if self.transformer is not None
+            else self.transformer_2.config.patch_size[1]
+        )

         mask_list = []
         for mask_, reference_images_batch in zip(mask, reference_images):
@@ -844,20 +862,25 @@ def __call__(
         batch_size = prompt_embeds.shape[0]

         vae_dtype = self.vae.dtype
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype

+        vace_layers = (
+            self.transformer.config.vace_layers
+            if self.transformer is not None
+            else self.transformer_2.config.vace_layers
+        )
         if isinstance(conditioning_scale, (int, float)):
-            conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
+            conditioning_scale = [conditioning_scale] * len(vace_layers)
         if isinstance(conditioning_scale, list):
-            if len(conditioning_scale) != len(self.transformer.config.vace_layers):
+            if len(conditioning_scale) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = torch.tensor(conditioning_scale)
         if isinstance(conditioning_scale, torch.Tensor):
-            if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
+            if conditioning_scale.size(0) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
                 )
         conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)

@@ -900,7 +923,11 @@ def __call__(
         conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
         conditioning_latents = conditioning_latents.to(transformer_dtype)

-        num_channels_latents = self.transformer.config.in_channels
+        num_channels_latents = (
+            self.transformer.config.in_channels
+            if self.transformer is not None
+            else self.transformer_2.config.in_channels
+        )
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -968,7 +995,7 @@ def __call__(
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
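
The last change swaps `guidance_scale` for `current_guidance_scale`, so the classifier-free guidance strength follows the active stage. A small sketch of the per-step selection this implies, using only the boundary semantics stated in the class docstring (names are illustrative):

    def select_guidance_scale(t, boundary_timestep, guidance_scale, guidance_scale_2):
        # Illustrative only: the high-noise stage (t >= boundary) keeps `guidance_scale`,
        # the low-noise stage switches to `guidance_scale_2` when a boundary is configured.
        if boundary_timestep is None or guidance_scale_2 is None:
            return guidance_scale
        return guidance_scale if t >= boundary_timestep else guidance_scale_2
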

tests/pipelines/wan/test_wan_vace.py

Lines changed: 87 additions & 2 deletions
@@ -12,16 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import tempfile
 import unittest

 import numpy as np
 import torch
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
+from diffusers import (
+    AutoencoderKLWan,
+    FlowMatchEulerDiscreteScheduler,
+    UniPCMultistepScheduler,
+    WanVACEPipeline,
+    WanVACETransformer3DModel,
+)

-from ...testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism, torch_device
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin

@@ -212,3 +219,81 @@ def test_float16_inference(self):
     )
     def test_save_load_float16(self):
         pass
+
+    def test_inference_with_only_transformer(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = None
+        components["boundary_ratio"] = 0.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_inference_with_only_transformer_2(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        components["transformer"] = None
+
+        # FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
+        # because starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+
+        components["boundary_ratio"] = 1.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_save_load_optional_components(self, expected_max_difference=1e-4):
+        optional_component = ["transformer"]
+
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        # FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
+        # because starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+        for component in optional_component:
+            components[component] = None
+
+        components["boundary_ratio"] = 1.0
+
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir, safe_serialization=False)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            for component in pipe_loaded.components.values():
+                if hasattr(component, "set_default_attn_processor"):
+                    component.set_default_attn_processor()
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for component in optional_component:
+            assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
+
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+        assert max_diff < expected_max_difference, "Outputs exceed expected maximum difference"
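
The low-noise-only test mirrors what a user of a two-expert checkpoint would do to keep only `transformer_2` in memory. A rough end-user sketch, assuming a hypothetical repo id and that optional components and the `boundary_ratio` config value can be overridden through `from_pretrained` kwargs:

    from diffusers import UniPCMultistepScheduler, WanVACEPipeline

    # Hypothetical repo id; transformer=None skips loading the high-noise expert,
    # boundary_ratio=1.0 routes every denoising step to transformer_2.
    pipe = WanVACEPipeline.from_pretrained(
        "org/wan-vace-2-stage-checkpoint",
        transformer=None,
        boundary_ratio=1.0,
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(
        pipe.scheduler.config, prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
    )
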
