
Commit c089372

update

1 parent 20d738c commit c089372

File tree: 2 files changed (+10, -10 lines changed)

src/diffusers/models/transformers/transformer_wan.py
Lines changed: 2 additions & 2 deletions

@@ -115,9 +115,9 @@ def __init__(self, in_features: int, out_features: int):
         self.norm2 = nn.LayerNorm(out_features)
 
     def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.norm1(encoder_hidden_states_image)
+        hidden_states = self.norm1(encoder_hidden_states_image.float()).type_as(encoder_hidden_states_image)
         hidden_states = self.ff(hidden_states)
-        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.norm2(hidden_states.float()).type_as(encoder_hidden_states_image)
         return hidden_states
 
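The change above follows a common mixed-precision pattern: compute the LayerNorm in float32 and cast the result back to the activation dtype, so the normalization statistics do not lose precision when the model runs in fp16 or bf16. A minimal sketch of the pattern, with illustrative sizes rather than the real WanImageEmbedding configuration:

import torch
import torch.nn as nn

# Minimal sketch of the upcast-then-downcast trick; the layer width and input
# shape below are illustrative, not the actual WanImageEmbedding configuration.
norm = nn.LayerNorm(1280)
x = torch.randn(1, 257, 1280, dtype=torch.bfloat16)

# Compute the normalization in float32, then cast the result back to the
# activation dtype so downstream layers keep running in bf16.
y = norm(x.float()).type_as(x)
assert y.dtype == torch.bfloat16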

src/diffusers/pipelines/wan/pipeline_wan_i2v.py
Lines changed: 8 additions & 8 deletions

@@ -19,7 +19,7 @@
 import PIL
 import regex as re
 import torch
-from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
@@ -137,7 +137,7 @@ def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModel,
+        image_encoder: CLIPVisionModelWithProjection,
         image_processor: CLIPImageProcessor,
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
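The import and type-hint change swap in CLIPVisionModelWithProjection, the transformers class that also carries the visual projection head: its forward returns a projected image_embeds tensor in addition to last_hidden_state, which plain CLIPVisionModel does not expose. A small sketch of that output difference, using a public CLIP checkpoint purely for illustration (it is not the encoder the Wan I2V pipeline actually loads):

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Illustrative checkpoint only; the pipeline loads the image encoder bundled
# with the Wan I2V model repository, not this one.
repo = "openai/clip-vit-large-patch14"
image_encoder = CLIPVisionModelWithProjection.from_pretrained(repo)
image_processor = CLIPImageProcessor.from_pretrained(repo)

inputs = image_processor(images=Image.new("RGB", (512, 512)), return_tensors="pt")
with torch.no_grad():
    outputs = image_encoder(**inputs)

# Unlike CLIPVisionModel, the *WithProjection variant also returns the pooled
# embedding passed through the visual projection head.
print(outputs.image_embeds.shape)       # (1, projection_dim)
print(outputs.last_hidden_state.shape)  # (1, num_patches + 1, hidden_size)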
@@ -345,9 +345,6 @@ def prepare_latents(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if latents is not None:
-            return latents.to(device=device, dtype=dtype)
-
         num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         latent_height = height // self.vae_scale_factor_spatial
         latent_width = width // self.vae_scale_factor_spatial

@@ -359,11 +356,14 @@ def prepare_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device=device, dtype=dtype)
 
         image = image.unsqueeze(2)
         video_condition = torch.cat(
-            [image, torch.zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+            [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
         )
         video_condition = video_condition.to(device=device, dtype=dtype)
 
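Two things change in prepare_latents. First, user-supplied latents are now moved to the requested device and dtype while the method still goes on to build the conditioning tensor, whereas the old early return handed back a bare tensor even though the signature promises a (latents, condition) tuple. Second, the zero padding for the video condition is allocated with image.new_zeros, so it inherits the device and dtype of image instead of defaulting to a float32 CPU tensor. A small sketch of the new_zeros point, with illustrative shapes only:

import torch

# Illustrative shapes only: one conditioning frame padded out to 81 frames.
image = torch.randn(1, 3, 1, 480, 832, dtype=torch.float16)

# new_zeros inherits dtype and device from `image`, so the concatenation never
# mixes a float32 CPU tensor with a half-precision (or CUDA) tensor.
pad = image.new_zeros(image.shape[0], image.shape[1], 80, 480, 832)
assert pad.dtype == image.dtype and pad.device == image.device

video_condition = torch.cat([image, pad], dim=2)
print(video_condition.shape)  # torch.Size([1, 3, 81, 480, 832])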
@@ -564,7 +564,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels
+        num_channels_latents = self.vae.config.z_dim
         image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
         latents, condition = self.prepare_latents(
             image,
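Reading vae.config.z_dim here presumably reflects that, for an image-to-video transformer, config.in_channels also counts the conditioning channels concatenated to the noisy latents, so it overstates how many channels the sampler should draw; the VAE's z_dim is the latent channel count itself. A sketch with assumed values (16 latent channels is an illustration, not read from a real Wan checkpoint):

import torch
from diffusers.utils.torch_utils import randn_tensor

# Assumed, illustrative values: the VAE emits z_dim latent channels, while the
# I2V transformer's in_channels would also count the concatenated conditioning.
z_dim = 16
batch_size, num_latent_frames, latent_height, latent_width = 1, 21, 60, 104

shape = (batch_size, z_dim, num_latent_frames, latent_height, latent_width)
latents = randn_tensor(shape, device=torch.device("cpu"), dtype=torch.float32)
print(latents.shape)  # torch.Size([1, 16, 21, 60, 104])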
