diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py new file mode 100644 index 000000000000..65da51946b21 --- /dev/null +++ b/scripts/convert_wan_to_diffusers.py @@ -0,0 +1,196 @@ +import argparse +import pathlib +from typing import Any, Dict + +import torch +from accelerate import init_empty_weights +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer +from huggingface_hub import snapshot_download +from safetensors.torch import load_file + +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + WanPipeline, + WanTransformer3DModel, +) + + +TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", +} + +TRANSFORMER_SPECIAL_KEYS_REMAP = {} + +VAE_KEYS_RENAME_DICT = {} + +VAE_SPECIAL_KEYS_REMAP = {} + + +def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: + state_dict[new_key] = state_dict.pop(old_key) + + +def load_sharded_safetensors(dir: pathlib.Path): + file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors")) + print(file_paths) + state_dict = {} + for path in file_paths: + state_dict.update(load_file(path)) + return state_dict + + +def get_transformer_config(model_type: str) -> Dict[str, Any]: + if model_type == "Wan-T2V-1.3B": + config = { + "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 12, + "num_layers": 30, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + } + } + elif model_type == "Wan-T2V-14B": + config = { + "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + } + } + return config + + +def convert_transformer(model_type: str): + config = get_transformer_config(model_type) + diffusers_config = config["diffusers_config"] + model_id = config["model_id"] + model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model")) + + original_state_dict = load_sharded_safetensors(model_dir) + + with init_empty_weights(): + transformer = WanTransformer3DModel.from_config(diffusers_config) + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(original_state_dict, key, new_key) + + for key in 
list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True, assign=True) + return transformer + + +# def convert_vae(ckpt_path: str): +# original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) + +# with init_empty_weights(): +# vae = AutoencoderKLWan() + +# for key in list(original_state_dict.keys()): +# new_key = key[:] +# for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): +# new_key = new_key.replace(replace_key, rename_key) +# update_state_dict_(original_state_dict, key, new_key) + +# for key in list(original_state_dict.keys()): +# for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): +# if special_key not in key: +# continue +# handler_fn_inplace(key, original_state_dict) + +# vae.load_state_dict(original_state_dict, strict=True, assign=True) +# return vae + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_type", type=str, default=None) + # parser.add_argument("--vae_ckpt_path", type=str, default=None) + # parser.add_argument("--text_encoder_path", type=str, default=None) + # parser.add_argument("--tokenizer_path", type=str, default=None) + # parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--dtype", default="fp32") + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +if __name__ == "__main__": + args = get_args() + + transformer = None + dtype = DTYPE_MAPPING[args.dtype] + + transformer = convert_transformer(args.model_type) + transformer = transformer.to(dtype=dtype) + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + # if args.vae_ckpt_path is not None: + # vae = convert_vae(args.vae_ckpt_path) + # if not args.save_pipeline: + # vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + # if args.save_pipeline: + # # TODO(aryan): update these + # text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16) + # tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right") + # scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + # pipe = WanPipeline( + # transformer=transformer, + # vae=vae, + # text_encoder=text_encoder, + # tokenizer=tokenizer, + # scheduler=scheduler, + # ) + # pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b5e6c57de45e..c99c39e074bc 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -111,6 +111,7 @@ AutoencoderKLLTXVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, + AutoencoderKLWan, AutoencoderOobleck, AutoencoderTiny, ConsistencyDecoderVAE, diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 4291598f5421..6f91f0453336 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -19,16 +19,14 @@ import torch.nn as nn import torch.nn.functional as F -from diffusers.loaders import FromOriginalModelMixin - from ...configuration_utils 
import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import logging from ..attention import FeedForward -from ..attention_processor import Attention, AttentionProcessor -from ..cache_utils import CacheMixin +from ..attention_processor import Attention +from ..embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -47,25 +45,16 @@ def __call__( hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - grid_sizes: Optional[torch.Tensor] = None, - freqs: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - batch_size, _, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - # i2v task encoder_hidden_states_img = None if attn.add_k_proj is not None: encoder_hidden_states_img = encoder_hidden_states[:, :257] encoder_hidden_states = encoder_hidden_states[:, 257:] - - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: encoder_hidden_states = hidden_states - + + query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -74,161 +63,159 @@ def __call__( if attn.norm_k is not None: key = attn.norm_k(key) - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim) - key = key.view(batch_size, -1, attn.heads, head_dim) - value = value.view(batch_size, -1, attn.heads, head_dim) - - if grid_sizes is not None and freqs is not None: - query = rope_apply(query, grid_sizes, freqs) - key = rope_apply(key, grid_sizes, freqs) - - query = query.transpose(1, 2) - - # i2v task + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if rotary_emb is not None: + def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): + x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) + x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) + return x_out.type_as(hidden_states) + + query = apply_rotary_emb(query, rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + + # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: key_img = attn.add_k_proj(encoder_hidden_states_img) - key_img = attn.norm_added_k(key_img).view(batch_size, -1, attn.heads, head_dim) - value_img = attn.add_v_proj(encoder_hidden_states_img).view(batch_size, -1, attn.heads, head_dim) - key_img = key_img.transpose(1, 2) - value_img = value_img.transpose(1, 2) + key_img = attn.norm_added_k(key_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + hidden_states_img = F.scaled_dot_product_attention( query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False ) hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) - - - key = key.transpose(1, 2) - value = value.transpose(1, 2) + hidden_states_img = 
hidden_states_img.type_as(query) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) + hidden_states = hidden_states.type_as(query) if hidden_states_img is not None: hidden_states = hidden_states + hidden_states_img - # linear proj hidden_states = attn.to_out[0](hidden_states) - # dropout hidden_states = attn.to_out[1](hidden_states) - return hidden_states -def sinusoidal_embedding_1d(dim, position): - # preprocess - assert dim % 2 == 0 - half = dim // 2 - position = position.type(torch.float64) - - # calculation - sinusoid = torch.outer( - position, torch.pow(10000, -torch.arange(half).to(position).div(half))) - x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) - return x - - -@torch.cuda.amp.autocast(enabled=False) -def rope_params(max_seq_len, dim, theta=10000): - assert dim % 2 == 0 - freqs = torch.outer( - torch.arange(max_seq_len), - 1.0 / torch.pow(theta, - torch.arange(0, dim, 2).to(torch.float64).div(dim))) - freqs = torch.polar(torch.ones_like(freqs), freqs) - return freqs - +class WanImageEmbedding(torch.nn.Module): + def __init__(self, in_features: int, out_features: int): + super().__init__() -@torch.cuda.amp.autocast(enabled=False) -def rope_apply(x, grid_sizes, freqs): - n, c = x.size(2), x.size(3) // 2 + self.norm1 = nn.LayerNorm(in_features) + self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") + self.norm2 = nn.LayerNorm(out_features, elementwise_affine=False) - # split freqs - freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states - # loop over samples - output = [] - for i, (f, h, w) in enumerate(grid_sizes.tolist()): - seq_len = f * h * w - # precompute multipliers - x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( - seq_len, n, -1, 2)) - freqs_i = torch.cat([ - freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), - freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), - freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) - ], - dim=-1).reshape(seq_len, 1, -1) +class WanTimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embedding_dim: Optional[int] = None, + ): + super().__init__() - # apply rotary embedding - x_i = torch.view_as_real(x_i * freqs_i).flatten(2) - x_i = torch.cat([x_i, x[i, seq_len:]]) + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + self.image_embedder = None + if image_embedding_dim is not None: + self.image_embedder = WanImageEmbedding(image_embedding_dim, dim) - # append to collection - output.append(x_i) - return torch.stack(output).float() + def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None): + timestep = self.timesteps_proj(timestep) + # TODO: We should remove the 
type_as here. `_keep_modules_in_fp32` is not actually keeping time_embedder layer in fp32 + temb = self.time_embedder(timestep.type_as(encoder_hidden_states)) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image -class WanRMSNorm(nn.Module): - def __init__(self, dim, eps=1e-5): +class WanRotaryPosEmbed(nn.Module): + def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0): super().__init__() - self.dim = dim - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - return self._norm(x.float()).type_as(x) * self.weight - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len -class WanLayerNorm(nn.LayerNorm): - - def __init__(self, dim, eps=1e-6, elementwise_affine=False): - super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) - - def forward(self, x): - return super().forward(x.float()).type_as(x) + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + freqs = [] + for dim in [t_dim, h_dim, w_dim]: + freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64) + freqs.append(freq) + self.freqs = torch.cat(freqs, dim=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + self.freqs = self.freqs.to(hidden_states.device) + freqs = self.freqs.split_with_sizes( + [self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), self.attention_head_dim // 6, self.attention_head_dim // 6], dim=1 + ) + + freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) + return freqs -class WanBlock(nn.Module): +class WanTransformerBlock(nn.Module): def __init__(self, - dim, - ffn_dim, - num_heads, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=False, - eps=1e-6, - added_kv_proj_dim=None): + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: Optional[int] = None + ): super().__init__() self.dim = dim self.ffn_dim = ffn_dim self.num_heads = num_heads - self.window_size = window_size self.qk_norm = qk_norm self.cross_attn_norm = cross_attn_norm self.eps = eps - self.norm1 = WanLayerNorm(dim, eps) - # self attn + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) self.attn1 = Attention( query_dim=dim, heads=num_heads, kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm="rms_norm_across_heads" if qk_norm else None, + qk_norm=qk_norm, eps=eps, bias=True, cross_attention_dim=None, @@ -236,13 +223,13 @@ def __init__(self, processor=WanAttnProcessor2_0(), ) - # cross attn + # 2. 
Cross-attention self.attn2 = Attention( query_dim=dim, heads=num_heads, kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm="rms_norm_across_heads" if qk_norm else None, + qk_norm=qk_norm, eps=eps, bias=True, cross_attention_dim=None, @@ -251,106 +238,42 @@ def __init__(self, added_proj_bias=True, processor=WanAttnProcessor2_0(), ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) - self.norm3 = WanLayerNorm( - dim, eps, - elementwise_affine=True) if cross_attn_norm else nn.Identity() - - self.norm2 = WanLayerNorm(dim, eps) - self.ffn = nn.Sequential( - nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), - nn.Linear(ffn_dim, dim)) - - # modulation - self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) def forward( self, - hidden_states, - e, - encoder_hidden_states, - seq_lens, - grid_sizes, - freqs, - context_lens, - attention_kwargs: Optional[Dict[str, Any]] = None, - ): - assert e.dtype == torch.float32 - with torch.cuda.amp.autocast(dtype=torch.float32): - e = (self.modulation + e).chunk(6, dim=1) - assert e[0].dtype == torch.float32 - - # self-attention - attn_hidden_states = self.norm1(hidden_states).float() * (1 + e[1]) + e[0] - - attn_hidden_states = self.attn1( - hidden_states=attn_hidden_states, - grid_sizes=grid_sizes, - freqs=freqs, - ) - with torch.cuda.amp.autocast(dtype=torch.float32): - hidden_states = hidden_states + attn_hidden_states * e[2] - - # cross-attention - attn_hidden_states = self.norm3(hidden_states) - attn_hidden_states = self.attn2( - hidden_states=attn_hidden_states, - encoder_hidden_states=encoder_hidden_states, - grid_sizes=None, - freqs=None, - ) - hidden_states = hidden_states + attn_hidden_states - - # ffn - ffn_hidden_states = self.norm2(hidden_states).float() * (1 + e[4]) + e[3] - ffn_hidden_states = self.ffn(ffn_hidden_states) - with torch.cuda.amp.autocast(dtype=torch.float32): - hidden_states = hidden_states + ffn_hidden_states * e[5] + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + ) -> torch.Tensor: + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (self.scale_shift_table + temb.float()).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # 3. 
Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) return hidden_states -class WanHead(nn.Module): - - def __init__(self, dim, out_dim, patch_size, eps=1e-6): - super().__init__() - self.dim = dim - self.out_dim = out_dim - self.patch_size = patch_size - self.eps = eps - - # layers - out_dim = math.prod(patch_size) * out_dim - self.norm = WanLayerNorm(dim, eps) - self.head = nn.Linear(dim, out_dim) - - # modulation - self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) - - def forward(self, x, e): - assert e.dtype == torch.float32 - with torch.cuda.amp.autocast(dtype=torch.float32): - e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) - return x - - -class MLPProj(torch.nn.Module): - - def __init__(self, in_dim, out_dim): - super().__init__() - - self.proj = torch.nn.Sequential( - torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), - torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), - torch.nn.LayerNorm(out_dim)) - - def forward(self, image_embeds): - clip_extra_context_tokens = self.proj(image_embeds) - return clip_extra_context_tokens - - -class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): +class WanTransformer3DModel(ModelMixin, ConfigMixin): r""" A Transformer model for video-like data used in the Wan model. @@ -388,11 +311,9 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi """ _supports_gradient_checkpointing = True - _skip_layerwise_casting_patterns = ["patch_embedding", "text_embedding", "time_embedding", "time_projection"] - _no_split_modules = [ - "WanBlock", - "WanHead", - ] + _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] + _no_split_modules = ["WanTransformerBlock"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] @register_to_config def __init__( @@ -402,226 +323,92 @@ def __init__( attention_head_dim: int = 128, in_channels: int = 16, out_channels: int = 16, - text_dim: int = 512, + text_dim: int = 4096, freq_dim: int = 256, ffn_dim: int = 13824, num_layers: int = 40, - window_size: Tuple[int] = (-1, -1), cross_attn_norm: bool = True, - qk_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, - add_img_emb: bool = False, + image_embedding_dim: Optional[int] = None, added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, ) -> None: super().__init__() inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels - self.freq_dim = freq_dim - self.out_channels = out_channels - self.patch_size = patch_size - - # embeddings - self.patch_embedding = nn.Conv3d( - in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) - self.text_embedding = nn.Sequential( - nn.Linear(text_dim, inner_dim), nn.GELU(approximate='tanh'), - nn.Linear(inner_dim, inner_dim)) - - self.time_embedding = nn.Sequential( - nn.Linear(freq_dim, inner_dim), nn.SiLU(), nn.Linear(inner_dim, inner_dim)) - self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(inner_dim, inner_dim * 6)) + # 1. 
Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Condition embeddings + # image_embedding_dim=1280 for I2V model + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embedding_dim=image_embedding_dim, + ) - # blocks + # 3. Transformer blocks self.blocks = nn.ModuleList([ - WanBlock(inner_dim, ffn_dim, num_attention_heads, - window_size, qk_norm, cross_attn_norm, eps, - added_kv_proj_dim) + WanTransformerBlock(inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim) for _ in range(num_layers) ]) - # head - self.head = WanHead(inner_dim, out_channels, patch_size, eps) - - # buffers (don't use register_buffer otherwise dtype will be changed in to()) - assert attention_head_dim % 2 == 0 - self.freqs = torch.cat([ - rope_params(1024, attention_head_dim - 4 * (attention_head_dim // 6)), - rope_params(1024, 2 * (attention_head_dim // 6)), - rope_params(1024, 2 * (attention_head_dim // 6)) - ], - dim=1) - - self.add_img_emb = add_img_emb - if add_img_emb: - self.img_emb = MLPProj(1280, inner_dim) + # 4. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - def forward( self, hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, - seq_len: int, - attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_hidden_states_image: Optional[torch.Tensor] = None, return_dict: bool = True, - **kwargs ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w - if self.freqs.device != hidden_states.device: - self.freqs = self.freqs.to(hidden_states.device) + rotary_emb = self.rope(hidden_states) hidden_states = self.patch_embedding(hidden_states) - - grid_sizes = torch.stack( - [torch.tensor(u.shape[1:], dtype=torch.long) for u in hidden_states] - ) - hidden_states = hidden_states.flatten(2).transpose(1, 2) # (b, l, c) - seq_lens = torch.tensor([u.size(0) for u in hidden_states], dtype=torch.long) - assert seq_lens.max() <= seq_len - hidden_states = torch.cat([ - torch.cat([u.unsqueeze(0), u.new_zeros(1, seq_len - u.size(0), u.size(1))], - dim=1) for u in hidden_states - ]) - - with torch.cuda.amp.autocast(dtype=torch.float32): - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, timestep).float()) - e0 = self.time_projection(e).unflatten(1, (6, -1)) - assert e.dtype == torch.float32 and e0.dtype == torch.float32 + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) - context_lens = None - encoder_hidden_states = self.text_embedding(encoder_hidden_states) - if self.add_img_emb: - img_encoder_hidden_states = kwargs.get('img_encoder_hidden_states', None) - if img_encoder_hidden_states is None: - raise ValueError('`add_img_emb` is set but `img_encoder_hidden_states` is not provided.') - img_encoder_hidden_states = self.img_emb(img_encoder_hidden_states) - encoder_hidden_states = torch.concat([img_encoder_hidden_states, encoder_hidden_states], dim=1) + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) # 4. 
Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: for block in self.blocks: - hidden_states = self._gradient_checkpointing_func( - block, - hidden_states, - e0, - encoder_hidden_states, - seq_lens, - grid_sizes, - self.freqs, - context_lens, - attention_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) else: for block in self.blocks: - hidden_states = block( - hidden_states, - e0, - encoder_hidden_states, - seq_lens, - grid_sizes, - self.freqs, - context_lens, - attention_kwargs, - ) - - # Output projection - hidden_states = self.head(hidden_states, e) - - # 5. Unpatchify - hidden_states = self.unpatchify(hidden_states, grid_sizes) - hidden_states = torch.stack(hidden_states) + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) - if not return_dict: - return (hidden_states,) - - return Transformer2DModelOutput(sample=hidden_states) + # 5. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - def unpatchify(self, hidden_states, grid_sizes): - c = self.out_channels - out = [] - for u, v in zip(hidden_states, grid_sizes.tolist()): - u = u[:math.prod(v)].view(*v, *self.patch_size, c) - u = torch.einsum('fhwpqrc->cfphqwr', u) - u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) - out.append(u) - return out + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/wan/__init__.py b/src/diffusers/pipelines/wan/__init__.py index 51fec66b99d4..61ed95d8651c 100644 --- a/src/diffusers/pipelines/wan/__init__.py +++ b/src/diffusers/pipelines/wan/__init__.py @@ -33,8 +33,8 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_wanx import WanxPipeline - from .pipeline_wanx_i2v import WanxI2VPipeline + from .pipeline_wan import WanPipeline + from .pipeline_wan_i2v import WanI2VPipeline else: import sys diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 3a5829eb3b0b..962f809171a8 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -26,7 +26,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKLWan, WanTransformer3DModel -from ...schedulers import UniPCMultistepScheduler +from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -138,7 +138,7 @@ def __init__( text_encoder: UMT5EncoderModel, transformer: WanTransformer3DModel, vae: AutoencoderKLWan, - scheduler: UniPCMultistepScheduler, + scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() @@ -150,7 +150,6 @@ def __init__( scheduler=scheduler, ) - self.patch_size = self.transformer.patch_size self.vae_scale_factor_temporal = 2 ** 
sum(self.vae.temperal_downsample) self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -200,6 +199,7 @@ def encode_prompt( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = False, num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, @@ -217,6 +217,8 @@ def encode_prompt( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. num_videos_per_prompt (`int`, *optional*, defaults to 1): Number of videos that should be generated per prompt. torch device to place the resulting embeddings on prompt_embeds (`torch.Tensor`, *optional*): @@ -248,7 +250,7 @@ def encode_prompt( dtype=dtype, ) - if negative_prompt_embeds is None: + if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt @@ -350,12 +352,12 @@ def guidance_scale(self): return self._guidance_scale @property - def num_timesteps(self): - return self._num_timesteps + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 @property - def attention_kwargs(self): - return self._attention_kwargs + def num_timesteps(self): + return self._num_timesteps @property def current_timestep(self): @@ -382,16 +384,13 @@ def __call__( latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - autocast_dtype: torch.dtype = torch.bfloat16, ): r""" The call function to the pipeline for generation. @@ -431,10 +430,6 @@ def __call__( The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. - attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. 
with the following arguments: `callback_on_step_end(self: @@ -471,7 +466,6 @@ def __call__( ) self._guidance_scale = guidance_scale - self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -489,16 +483,18 @@ def __call__( prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, device=device, - dtype=autocast_dtype, ) - prompt_embeds = prompt_embeds.to(autocast_dtype) - negative_prompt_embeds = negative_prompt_embeds.to(autocast_dtype) + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) # 4. Prepare timesteps self.scheduler.flow_shift = flow_shift @@ -526,45 +522,31 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - *_, latent_f, latent_h, latent_w = latents.shape - seq_len = math.ceil((latent_h * latent_w) / - (self.patch_size[1] * self.patch_size[2]) * - latent_f) - - with ( - self.progress_bar(total=num_inference_steps) as progress_bar, - amp.autocast('cuda', dtype=autocast_dtype, cache_enabled=False) - ): + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue self._current_timestep = t - latent_model_input = latents - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, - seq_len=seq_len, - attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred_negative = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - seq_len=seq_len, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - noise_pred = noise_pred_negative + guidance_scale * ( - noise_pred - noise_pred_negative) - + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 722b63f3687a..ac015032b684 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -238,10 +238,12 @@ def encode_image(self, image: PipelineImageInput): image_embeds = self.image_encoder(**image, output_hidden_states=True) return image_embeds.hidden_states[31] + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = False, num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, 
negative_prompt_embeds: Optional[torch.Tensor] = None, @@ -259,6 +261,8 @@ def encode_prompt( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. num_videos_per_prompt (`int`, *optional*, defaults to 1): Number of videos that should be generated per prompt. torch device to place the resulting embeddings on prompt_embeds (`torch.Tensor`, *optional*): @@ -290,7 +294,7 @@ def encode_prompt( dtype=dtype, ) - if negative_prompt_embeds is None: + if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt @@ -368,7 +372,7 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - ) -> (torch.Tensor, torch.Tensor): + ) -> Tuple[torch.Tensor, torch.Tensor]: aspect_ratio = height / width mod_value = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value @@ -440,12 +444,12 @@ def guidance_scale(self): return self._guidance_scale @property - def num_timesteps(self): - return self._num_timesteps + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 @property - def attention_kwargs(self): - return self._attention_kwargs + def num_timesteps(self): + return self._num_timesteps @property def current_timestep(self): @@ -472,16 +476,13 @@ def __call__( latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - autocast_dtype: torch.dtype = torch.bfloat16, ): r""" The call function to the pipeline for generation. @@ -522,10 +523,6 @@ def __call__( The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. - attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. 
with the following arguments: `callback_on_step_end(self: @@ -563,7 +560,6 @@ def __call__( ) self._guidance_scale = guidance_scale - self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -581,20 +577,22 @@ def __call__( prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, device=device, - dtype=autocast_dtype, ) - # encode image embedding + + # Encode image embedding image_embeds = self.encode_image(image) image_embeds = image_embeds.repeat(batch_size, 1, 1) - prompt_embeds = prompt_embeds.to(autocast_dtype) - negative_prompt_embeds = negative_prompt_embeds.to(autocast_dtype) - image_embeds = image_embeds.to(autocast_dtype) + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + image_embeds = image_embeds.to(transformer_dtype) # 4. Prepare timesteps self.scheduler.flow_shift = flow_shift @@ -606,6 +604,7 @@ def __call__( height, width = image.shape[-2:] else: width, height = image.size + # 5. Prepare latent variables num_channels_latents = self.vae.config.z_dim num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 @@ -628,48 +627,32 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - *_, num_latent_frames, latent_h, latent_w = latents.shape - seq_len = num_latent_frames * latent_h * latent_w // ( - self.transformer.config.patch_size[0] * \ - self.transformer.config.patch_size[1] * \ - self.transformer.config.patch_size[2] - ) - - with ( - self.progress_bar(total=num_inference_steps) as progress_bar, - amp.autocast('cuda', dtype=autocast_dtype, cache_enabled=False) - ): + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue self._current_timestep = t - latent_model_input = latents - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) timestep = t.expand(latents.shape[0]) noise_pred = self.transformer( - hidden_states=torch.concat([latent_model_input, condition], dim=1), + hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, - img_encoder_hidden_states=image_embeds, - seq_len=seq_len, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - noise_pred_negative = self.transformer( - hidden_states=torch.concat([latent_model_input, condition], dim=1), - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - img_encoder_hidden_states=image_embeds, - seq_len=seq_len, - attention_kwargs=attention_kwargs, + encoder_hidden_states_image=image_embeds, return_dict=False, )[0] - noise_pred = noise_pred_negative + guidance_scale * ( - noise_pred - noise_pred_negative) + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample 
x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 9dd1e690742f..a4c715b5801d 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -111,6 +111,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLWan(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLAllegro(metaclass=DummyObject): _backends = ["torch"] @@ -966,6 +981,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class WanTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index c853cf8faa55..f15da5215487 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2595,3 +2595,33 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + + +class WanPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class WanI2VPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py new file mode 100644 index 000000000000..3ac64c628988 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_wan.py @@ -0,0 +1,81 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import unittest + +import torch + +from diffusers import WanTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = WanTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 1, 16, 16) + + @property + def output_shape(self): + return (4, 1, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 4, + "out_channels": 4, + "text_dim": 16, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"WanTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
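For reviewers, a minimal end-to-end sketch of how the pieces in this diff are intended to fit together, from checkpoint conversion to text-to-video generation. Only the class names, the `WanPipeline` constructor signature, the converter's CLI flags (`--model_type`, `--output_path`, `--dtype`), and the `shift=7.0` scheduler value come from the diff itself; the checkpoint paths, UMT5 location, prompt, step/guidance values, and the `.frames` attribute on the pipeline output are illustrative assumptions. The script currently converts only the transformer (VAE/pipeline export is still commented out), so the remaining components are loaded from placeholder locations.

```python
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from diffusers import (
    AutoencoderKLWan,
    FlowMatchEulerDiscreteScheduler,
    WanPipeline,
    WanTransformer3DModel,
)
from diffusers.utils import export_to_video

# 1. Convert the original transformer checkpoint (only the transformer is saved):
#    python scripts/convert_wan_to_diffusers.py \
#        --model_type Wan-T2V-1.3B --output_path ./wan-t2v-1.3b-transformer --dtype bf16

# 2. Assemble the pipeline manually. Paths and the UMT5 location are placeholders.
transformer = WanTransformer3DModel.from_pretrained(
    "./wan-t2v-1.3b-transformer", torch_dtype=torch.bfloat16
)
vae = AutoencoderKLWan.from_pretrained("<path-to-converted-wan-vae>", torch_dtype=torch.float32)
text_encoder = UMT5EncoderModel.from_pretrained("<path-to-umt5-xxl>", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("<path-to-umt5-xxl>")
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)  # mirrors the commented-out assembly path in the script

pipe = WanPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    vae=vae,
    scheduler=scheduler,
).to("cuda")

# 3. Generate. guidance_scale > 1.0 takes the classifier-free-guidance branch added
#    in this diff (a second, unconditional transformer forward per denoising step).
output = pipe(
    prompt="A cat walking on a beach at sunset",
    negative_prompt="blurry, low quality, distorted",
    num_inference_steps=50,
    guidance_scale=5.0,
)
export_to_video(output.frames[0], "wan_t2v.mp4", fps=16)
```

Note that the scheduler's flow shift can also be adjusted at call time: the pipelines set `self.scheduler.flow_shift = flow_shift` before preparing timesteps, so a per-call `flow_shift` argument overrides the value used at construction.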