diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py index 4887e2f1..07351a6e 100644 --- a/diffsynth/distributed/xdit_context_parallel.py +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -40,9 +40,9 @@ def rope_apply(x, freqs, num_heads): return x_out.to(x.dtype) def usp_dit_forward(self, - x: torch.Tensor, + latents: torch.Tensor, timestep: torch.Tensor, - context: torch.Tensor, + encoder_hidden_states: torch.Tensor, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, @@ -52,20 +52,20 @@ def usp_dit_forward(self, t = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) - context = self.text_embedding(context) + encoder_hidden_states = self.text_embedding(encoder_hidden_states) if self.has_image_input: - x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + latents = torch.cat([latents, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) - context = torch.cat([clip_embdding, context], dim=1) + encoder_hidden_states = torch.cat([clip_embdding, encoder_hidden_states], dim=1) - x, (f, h, w) = self.patchify(x) + latents, (f, h, w) = self.patchify(latents) freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) - ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + ], dim=-1).reshape(f * h * w, 1, -1).to(latents.device) def create_custom_forward(module): def custom_forward(*inputs): @@ -73,44 +73,44 @@ def custom_forward(*inputs): return custom_forward # Context Parallel - chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + chunks = torch.chunk(latents, get_sequence_parallel_world_size(), dim=1) pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] - x = chunks[get_sequence_parallel_rank()] + latents = chunks[get_sequence_parallel_rank()] for block in self.blocks: if self.training and use_gradient_checkpointing: if use_gradient_checkpointing_offload: with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( + latents = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, t_mod, freqs, + latents, encoder_hidden_states, t_mod, freqs, use_reentrant=False, ) else: - x = torch.utils.checkpoint.checkpoint( + latents = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, t_mod, freqs, + latents, encoder_hidden_states, t_mod, freqs, use_reentrant=False, ) else: - x = block(x, context, t_mod, freqs) + latents = block(latents, encoder_hidden_states, t_mod, freqs) - x = self.head(x, t) + latents = self.head(latents, t) # Context Parallel - x = get_sp_group().all_gather(x, dim=1) - x = x[:, :-pad_shape] if pad_shape > 0 else x + latents = get_sp_group().all_gather(latents, dim=1) + latents = latents[:, :-pad_shape] if pad_shape > 0 else latents # unpatchify - x = self.unpatchify(x, (f, h, w)) - return x + latents = self.unpatchify(latents, (f, h, w)) + return latents -def usp_attn_forward(self, x, freqs): - q = self.norm_q(self.q(x)) - k = self.norm_k(self.k(x)) - v = self.v(x) +def usp_attn_forward(self, latents, freqs): + q = self.norm_q(self.q(latents)) + k = self.norm_k(self.k(latents)) + v = self.v(latents) q = rope_apply(q, 
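# A minimal, standalone sketch of the chunk-and-pad scheme that usp_dit_forward
# uses to split the patchified sequence across sequence-parallel ranks. It is a
# pure-torch illustration: world_size/rank are passed explicitly instead of the
# xfuser helpers, and `pad` is kept so the zeros can be trimmed after all_gather.
import torch
import torch.nn.functional as F

def split_for_sequence_parallel(x: torch.Tensor, world_size: int, rank: int):
    # x: (batch, seq_len, dim); the last chunk may be shorter and is zero-padded
    # up to the size of the first chunk so every rank holds the same shape.
    chunks = torch.chunk(x, world_size, dim=1)
    pad = chunks[0].shape[1] - chunks[-1].shape[1]
    chunks = [F.pad(c, (0, 0, 0, chunks[0].shape[1] - c.shape[1]), value=0) for c in chunks]
    return chunks[rank], pad

x = torch.randn(1, 10, 8)
local, pad = split_for_sequence_parallel(x, world_size=4, rank=3)
print(local.shape, pad)  # torch.Size([1, 3, 8]) 2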
freqs, self.num_heads) k = rope_apply(k, freqs, self.num_heads) @@ -118,14 +118,14 @@ def usp_attn_forward(self, x, freqs): k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) - x = xFuserLongContextAttention()( + latents = xFuserLongContextAttention()( None, query=q, key=k, value=v, ) - x = x.flatten(2) + latents = latents.flatten(2) del q, k, v torch.cuda.empty_cache() - return self.o(x) \ No newline at end of file + return self.o(latents) \ No newline at end of file diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 1a54728f..1aea7ef8 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -211,7 +211,7 @@ def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) self.gate = GateModule() - def forward(self, x, context, t_mod, freqs): + def forward(self, hidden_states, encoder_hidden_states, t_mod, freqs): has_seq = len(t_mod.shape) == 4 chunk_dim = 2 if has_seq else 1 # msa: multi-head self-attention mlp: multi-layer perceptron @@ -222,12 +222,12 @@ def forward(self, x, context, t_mod, freqs): shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), ) - input_x = modulate(self.norm1(x), shift_msa, scale_msa) - x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) - x = x + self.cross_attn(self.norm3(x), context) - input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) - x = self.gate(x, gate_mlp, self.ffn(input_x)) - return x + input_x = modulate(self.norm1(hidden_states), shift_msa, scale_msa) + hidden_states = self.gate(hidden_states, gate_msa, self.self_attn(input_x, freqs)) + hidden_states = hidden_states + self.cross_attn(self.norm3(hidden_states), encoder_hidden_states) + input_x = modulate(self.norm2(hidden_states), shift_mlp, scale_mlp) + hidden_states = self.gate(hidden_states, gate_mlp, self.ffn(input_x)) + return hidden_states class MLP(torch.nn.Module): @@ -244,10 +244,10 @@ def __init__(self, in_dim, out_dim, has_pos_emb=False): if has_pos_emb: self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) - def forward(self, x): + def forward(self, hidden_states): if self.has_pos_emb: - x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) - return self.proj(x) + hidden_states = hidden_states + self.emb_pos.to(dtype=hidden_states.dtype, device=hidden_states.device) + return self.proj(hidden_states) class Head(nn.Module): @@ -259,14 +259,14 @@ def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) - def forward(self, x, t_mod): + def forward(self, hidden_states, t_mod): if len(t_mod.shape) == 3: shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) - x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) + hidden_states = (self.head(self.norm(hidden_states) * (1 + scale.squeeze(2)) + shift.squeeze(2))) else: shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + scale) + shift)) - return x + hidden_states = (self.head(self.norm(hidden_states) * (1 + scale) + shift)) + return hidden_states class WanModel(torch.nn.Module): @@ -354,9 +354,9 @@ def 
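# A minimal sketch of the adaptive-LayerNorm modulation used by DiTBlock and Head
# above. It assumes the standard DiT formulation that Head.forward spells out
# explicitly, norm(x) * (1 + scale) + shift, and that GateModule applies
# x + gate * branch_output; shapes and modules here are illustrative stand-ins.
import torch

def modulate(x, shift, scale):
    return x * (1 + scale) + shift

b, s, d = 1, 16, 32
x = torch.randn(b, s, d)
t_mod = torch.randn(b, 6, d)  # time_projection(t).unflatten(1, (6, dim))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod.chunk(6, dim=1)

norm1 = torch.nn.LayerNorm(d, elementwise_affine=False)
norm2 = torch.nn.LayerNorm(d, elementwise_affine=False)
attn_in = modulate(norm1(x), shift_msa, scale_msa)   # "input_x" in DiTBlock.forward
attn_out = torch.randn_like(attn_in)                 # stand-in for self_attn(attn_in, freqs)
x = x + gate_msa * attn_out                          # gated residual update
mlp_in = modulate(norm2(x), shift_mlp, scale_mlp)
print(mlp_in.shape)  # torch.Size([1, 16, 32])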
unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): ) def forward(self, - x: torch.Tensor, + hidden_states: torch.Tensor, timestep: torch.Tensor, - context: torch.Tensor, + encoder_hidden_states: torch.Tensor, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, @@ -366,20 +366,20 @@ def forward(self, t = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) - context = self.text_embedding(context) + context = self.text_embedding(encoder_hidden_states) if self.has_image_input: - x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + hidden_states = torch.cat([hidden_states, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - x, (f, h, w) = self.patchify(x) + hidden_states, (f, h, w) = self.patchify(hidden_states) freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) - ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + ], dim=-1).reshape(f * h * w, 1, -1).to(hidden_states.device) def create_custom_forward(module): def custom_forward(*inputs): @@ -390,23 +390,23 @@ def custom_forward(*inputs): if self.training and use_gradient_checkpointing: if use_gradient_checkpointing_offload: with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, t_mod, freqs, + hidden_states, context, t_mod, freqs, use_reentrant=False, ) else: - x = torch.utils.checkpoint.checkpoint( + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, t_mod, freqs, + hidden_states, context, t_mod, freqs, use_reentrant=False, ) else: - x = block(x, context, t_mod, freqs) + hidden_states = block(hidden_states, context, t_mod, freqs) - x = self.head(x, t) - x = self.unpatchify(x, (f, h, w)) - return x + hidden_states = self.head(hidden_states, t) + hidden_states = self.unpatchify(hidden_states, (f, h, w)) + return hidden_states @staticmethod def state_dict_converter(): diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index 70881e6d..6e4c5ee9 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -341,20 +341,20 @@ def forward(self, features): class WanS2VDiTBlock(DiTBlock): - def forward(self, x, context, t_mod, seq_len_x, freqs): + def forward(self, hidden_states, encoder_hidden_states, t_mod, seq_len_x, freqs): t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) # t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc. 
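# A minimal sketch of how WanModel.forward assembles the 3D RoPE frequency grid:
# one table per axis (frame, height, width) is broadcast over the (f, h, w) grid,
# concatenated along the feature axis, and flattened into a per-token table of
# shape (f*h*w, 1, head_dim_split). Table lengths and feature splits here are
# hypothetical stand-ins for self.freqs.
import torch

f, h, w = 4, 6, 8
d_f, d_h, d_w = 16, 8, 8
table_f, table_h, table_w = torch.randn(32, d_f), torch.randn(32, d_h), torch.randn(32, d_w)
freqs = torch.cat([
    table_f[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
    table_h[:h].view(1, h, 1, -1).expand(f, h, w, -1),
    table_w[:w].view(1, 1, w, -1).expand(f, h, w, -1),
], dim=-1).reshape(f * h * w, 1, -1)
print(freqs.shape)  # torch.Size([192, 1, 32])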
t_mod = [ - torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1) + torch.cat([element[:, :, 0].expand(1, seq_len_x, hidden_states.shape[-1]), element[:, :, 1].expand(1, hidden_states.shape[1] - seq_len_x, hidden_states.shape[-1])], dim=1) for element in t_mod ] shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod - input_x = modulate(self.norm1(x), shift_msa, scale_msa) - x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) - x = x + self.cross_attn(self.norm3(x), context) - input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) - x = self.gate(x, gate_mlp, self.ffn(input_x)) - return x + input_x = modulate(self.norm1(hidden_states), shift_msa, scale_msa) + hidden_states = self.gate(hidden_states, gate_msa, self.self_attn(input_x, freqs)) + hidden_states = hidden_states + self.cross_attn(self.norm3(hidden_states), encoder_hidden_states) + input_x = modulate(self.norm2(hidden_states), shift_mlp, scale_mlp) + hidden_states = self.gate(hidden_states, gate_mlp, self.ffn(input_x)) + return hidden_states class WanS2VModel(torch.nn.Module): @@ -505,94 +505,125 @@ def forward( self, latents, timestep, - context, + encoder_hidden_states, audio_input, motion_latents, pose_cond, use_gradient_checkpointing_offload=False, - use_gradient_checkpointing=False + use_gradient_checkpointing=False, + use_unified_sequence_parallel=False, + tea_cache=None, ): + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + + origin_ref_latents = latents[:, :, 0:1] - x = latents[:, :, 1:] + hidden_states = latents[:, :, 1:] # context embedding - context = self.text_embedding(context) + encoder_hidden_states = self.text_embedding(encoder_hidden_states) # audio encode audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input) # x and pose_cond - pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond - x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) - seq_len_x = x.shape[1] + pose_cond = torch.zeros_like(hidden_states) if pose_cond is None else pose_cond + hidden_states, (f, h, w) = self.patchify(self.patch_embedding(hidden_states) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) + seq_len_x = seq_len_x_global = hidden_states.shape[1] # reference image ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120]) grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw)) - x = torch.cat([x, ref_latents], dim=1) + hidden_states = torch.cat([hidden_states, ref_latents], dim=1) # mask - mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(hidden_states.device) # freqs pre_compute_freqs = rope_precompute( - x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None + hidden_states.detach().view(1, hidden_states.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None ) # motion - x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2) + hidden_states, pre_compute_freqs, mask = self.inject_motion(hidden_states, 
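# A minimal sketch of the per-segment modulation in WanS2VDiTBlock.forward: the
# first seq_len_x tokens (video latents) receive t_mod[:, :, 0] and the remaining
# tokens (reference/motion latents) receive t_mod[:, :, 1]; both are expanded and
# concatenated back into one (1, total_len, dim) tensor. Shapes are illustrative.
import torch

dim, seq_len_x, total_len = 32, 20, 26
element = torch.randn(1, 1, 2, dim)  # one of the six chunks of the modulation tensor
per_token = torch.cat([
    element[:, :, 0].expand(1, seq_len_x, dim),
    element[:, :, 1].expand(1, total_len - seq_len_x, dim),
], dim=1)
print(per_token.shape)  # torch.Size([1, 26, 32])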
pre_compute_freqs, mask, motion_latents, add_last_motion=2) - x = x + self.trainable_cond_mask(mask).to(x.dtype) + hidden_states = hidden_states + self.trainable_cond_mask(mask).to(hidden_states.dtype) # t_mod timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward - - for block_id, block in enumerate(self.blocks): - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( + if tea_cache is not None: + tea_cache_update = tea_cache.check(self, hidden_states, t_mod) + else: + tea_cache_update = False + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() + assert hidden_states.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {hidden_states.shape[1]} and {get_sequence_parallel_world_size()}" + hidden_states = torch.chunk(hidden_states, world_size, dim=1)[sp_rank] + seg_idxs = [0] + list(torch.cumsum(torch.tensor([hidden_states.shape[1]] * world_size), dim=0).cpu().numpy()) + seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), hidden_states.shape[1]) for i in range(len(seg_idxs)-1)] + seq_len_x = seq_len_x_list[sp_rank] + + if tea_cache_update: + hidden_states = tea_cache.update(hidden_states) + else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(self.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + t_mod, + seq_len_x, + pre_compute_freqs[0], + use_reentrant=False, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + hidden_states, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, - context, + hidden_states, + encoder_hidden_states, t_mod, seq_len_x, pre_compute_freqs[0], use_reentrant=False, ) - x = torch.utils.checkpoint.checkpoint( + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, + hidden_states, use_reentrant=False, ) - elif use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, - context, - t_mod, - seq_len_x, - pre_compute_freqs[0], - use_reentrant=False, - ) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, - use_reentrant=False, - ) - else: - x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) - x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x) + else: + hidden_states = block(hidden_states, encoder_hidden_states, t_mod, seq_len_x, pre_compute_freqs[0]) + hidden_states = 
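# A minimal sketch of the gradient-checkpointing pattern used in the block loops
# above: wrap the block call in a closure, let PyTorch recompute activations in
# the backward pass, and optionally keep the saved tensors on CPU
# (use_gradient_checkpointing_offload) to trade speed for GPU memory. The Linear
# layer is a stand-in for a transformer block.
import torch
import torch.utils.checkpoint

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

block = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
with torch.autograd.graph.save_on_cpu():
    x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
x.sum().backward()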
self.after_transformer_block(block_id, hidden_states, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel=use_unified_sequence_parallel) + + if tea_cache is not None: + tea_cache.store(hidden_states) - x = x[:, :seq_len_x] - x = self.head(x, t[:-1]) - x = self.unpatchify(x, (f, h, w)) + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + hidden_states = get_sp_group().all_gather(hidden_states, dim=1) + + hidden_states = hidden_states[:, :seq_len_x_global] + hidden_states = self.head(hidden_states, t[:-1]) + hidden_states = self.unpatchify(hidden_states, (f, h, w)) # make compatible with wan video - x = torch.cat([origin_ref_latents, x], dim=2) - return x + hidden_states = torch.cat([origin_ref_latents, hidden_states], dim=2) + return hidden_states @staticmethod def state_dict_converter(): diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 660a38e7..f9266d50 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -30,7 +30,7 @@ class WanVideoPipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None, offline_preprocessing=False): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 @@ -47,24 +47,9 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=Non self.in_iteration_models = ("dit", "motion_controller", "vace") self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") self.unit_runner = PipelineUnitRunner() - self.units = [ - WanVideoUnit_ShapeChecker(), - WanVideoUnit_NoiseInitializer(), - WanVideoUnit_PromptEmbedder(), - WanVideoUnit_S2V(), - WanVideoUnit_InputVideoEmbedder(), - WanVideoUnit_ImageEmbedderVAE(), - WanVideoUnit_ImageEmbedderCLIP(), - WanVideoUnit_ImageEmbedderFused(), - WanVideoUnit_FunControl(), - WanVideoUnit_FunReference(), - WanVideoUnit_FunCameraControl(), - WanVideoUnit_SpeedControl(), - WanVideoUnit_VACE(), - WanVideoUnit_UnifiedSequenceParallel(), - WanVideoUnit_TeaCache(), - WanVideoUnit_CfgMerger(), - ] + + self.initalize_units(offline_preprocessing=offline_preprocessing) + self.post_units = [ WanVideoPostUnit_S2V(), ] @@ -90,7 +75,7 @@ def training_loss(self, **inputs): loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) loss = loss * self.scheduler.training_weight(timestep) - return loss + return loss, noise_pred def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): @@ -285,7 +270,10 @@ def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=No def initialize_usp(self): import torch.distributed as dist from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment - dist.init_process_group(backend="nccl", init_method="env://") + + if not dist.is_initialized(): + dist.init_process_group(backend="nccl", init_method="env://") + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) initialize_model_parallel( sequence_parallel_degree=dist.get_world_size(), @@ -301,7 +289,7 @@ def enable_usp(self): for block in self.dit.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) - self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + 
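# A minimal sketch of the guarded initialization added to initialize_usp(): the
# NCCL process group is created only if the caller (for example a training
# framework) has not already done so. Running it requires the usual torchrun
# env:// environment variables.
import torch.distributed as dist

def init_distributed_if_needed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl", init_method="env://")
    return dist.get_rank(), dist.get_world_size()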
#self.dit.forward = types.MethodType(usp_dit_forward, self.dit) if self.dit2 is not None: for block in self.dit2.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) @@ -319,6 +307,7 @@ def from_pretrained( audio_processor_config: ModelConfig = None, redirect_common_files: bool = True, use_usp=False, + offline_preprocessing=False ): # Redirect model path if redirect_common_files: @@ -335,7 +324,7 @@ def from_pretrained( model_config.model_id = redirect_dict[model_config.origin_file_pattern] # Initialize pipeline - pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype, offline_preprocessing=offline_preprocessing) if use_usp: pipe.initialize_usp() # Download and load models @@ -367,9 +356,10 @@ def from_pretrained( pipe.width_division_factor = pipe.vae.upsampling_factor * 2 # Initialize tokenizer - tokenizer_config.download_if_necessary(use_usp=use_usp) - pipe.prompter.fetch_models(pipe.text_encoder) - pipe.prompter.fetch_tokenizer(tokenizer_config.path) + if pipe.text_encoder is not None: + tokenizer_config.download_if_necessary(use_usp=use_usp) + pipe.prompter.fetch_models(pipe.text_encoder) + pipe.prompter.fetch_tokenizer(tokenizer_config.path) if audio_processor_config is not None: audio_processor_config.download_if_necessary(use_usp=use_usp) @@ -377,9 +367,45 @@ def from_pretrained( pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path) # Unified Sequence Parallel if use_usp: pipe.enable_usp() - return pipe + return pipe + def initalize_units(self, offline_preprocessing=False): + + if not offline_preprocessing: + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_PromptEmbedder(), + WanVideoUnit_S2V(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger() + ] + else: + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_InputVideoEmbedderPassThrough(), + WanVideoUnit_ImageEmbedderFusingOnly(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger() + ] @torch.no_grad() def __call__( self, @@ -441,6 +467,7 @@ def __call__( tea_cache_model_id: Optional[str] = "", # progress_bar progress_bar_cmd=tqdm, + fps: Optional[int] = 16, ): # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) @@ -469,6 +496,7 @@ def __call__( "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, + "fps": fps, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -560,12 +588,24 @@ def process(self, pipe: WanVideoPipeline, 
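# A toy sketch of the types.MethodType patching that enable_usp() performs on
# block.self_attn.forward: rebinding a free function as a bound method swaps the
# forward of one already-constructed object without touching the class. Block and
# parallel_forward are hypothetical stand-ins for the real modules.
import types

class Block:
    def forward(self, x):
        return x

def parallel_forward(self, x):
    return x * 2  # stand-in for usp_attn_forward

block = Block()
block.forward = types.MethodType(parallel_forward, block)
print(block.forward(3))  # 6; other Block instances keep the original forward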
input_video, noise, tiled, tile_size, vace_reference_image = pipe.preprocess_video([vace_reference_image]) vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) - if pipe.scheduler.training: + elif pipe.scheduler.training: return {"latents": noise, "input_latents": input_latents} else: latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) return {"latents": latents} +class WanVideoUnit_InputVideoEmbedderPassThrough(PipelineUnit): + def __init__(self): + super().__init__(input_params=("input_latents", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image")) + + def process(self, pipe: WanVideoPipeline, input_latents, noise, tiled, tile_size, tile_stride, vace_reference_image): + if input_latents is None: + return {"latents": noise} + elif pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} class WanVideoUnit_PromptEmbedder(PipelineUnit): @@ -582,8 +622,6 @@ def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: prompt_emb = pipe.prompter.encode_prompt(prompt, positive=positive, device=pipe.device) return {"context": prompt_emb} - - class WanVideoUnit_ImageEmbedder(PipelineUnit): """ Deprecated @@ -691,15 +729,32 @@ def __init__(self): onload_model_names=("vae",) ) - def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride, pre_encoding=False): if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) - latents[:, :, 0: 1] = z - return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} + if pre_encoding: + return {"first_frame_latents": z} + else: + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} +class WanVideoUnit_ImageEmbedderFusingOnly(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. 
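# A minimal sketch of the branching implemented by
# WanVideoUnit_InputVideoEmbedderPassThrough: no pre-encoded latents -> start from
# pure noise; training -> return both tensors for the loss; otherwise noise the
# input latents to the first scheduler timestep. _StubScheduler is a hypothetical
# stand-in with the same call shape as the pipeline scheduler.
import torch

class _StubScheduler:
    timesteps = torch.tensor([999.0])
    def add_noise(self, sample, noise, timestep):
        return 0.5 * sample + 0.5 * noise  # placeholder mixing rule

def init_latents(input_latents, noise, scheduler, training):
    if input_latents is None:
        return {"latents": noise}
    if training:
        return {"latents": noise, "input_latents": input_latents}
    latents = scheduler.add_noise(input_latents, noise, timestep=scheduler.timesteps[0])
    return {"latents": latents}

out = init_latents(torch.zeros(1, 16, 4, 8, 8), torch.randn(1, 16, 4, 8, 8), _StubScheduler(), training=False)
print(out["latents"].shape)  # torch.Size([1, 16, 4, 8, 8])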
+ """ + def __init__(self): + super().__init__( + input_params=("latents","first_frame_latents") + ) + + def process(self, pipe: WanVideoPipeline, latents,first_frame_latents): + if first_frame_latents is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + latents[:, :, 0: 1] = first_frame_latents + return {"latents": latents, "fuse_vae_embedding_in_latents": True} class WanVideoUnit_FunControl(PipelineUnit): @@ -869,7 +924,9 @@ def __init__(self): def process(self, pipe: WanVideoPipeline): if hasattr(pipe, "use_unified_sequence_parallel"): if pipe.use_unified_sequence_parallel: + print("use_unified_sequence_parallel true") return {"use_unified_sequence_parallel": True} + return {} @@ -917,7 +974,7 @@ def __init__(self): onload_model_names=("audio_encoder", "vae",) ) - def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): + def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps, audio_embeds=None, return_all=False): if audio_embeds is not None: return {"audio_embeds": audio_embeds} pipe.load_models_to_device(["audio_encoder"]) @@ -971,8 +1028,8 @@ def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_neg num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio"), inputs_shared.pop("audio_embeds"), inputs_shared.get("audio_sample_rate") s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video"), inputs_shared.pop("s2v_pose_latents"), inputs_shared.pop("motion_video") - - audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds) + fps = inputs_shared.get("fps") + audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, audio_embeds=audio_embeds) inputs_posi.update(audio_input_positive) inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]}) @@ -1011,7 +1068,7 @@ def __init__(self, num_inference_steps, rel_l1_thresh, model_id): self.previous_modulated_input = None self.rel_l1_thresh = rel_l1_thresh self.previous_residual = None - self.previous_hidden_states = None + self.previous_hidden_states = None self.coefficients_dict = { "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], @@ -1162,20 +1219,20 @@ def model_fn_wan_video( tensor_names=["latents", "y"], batch_size=2 if cfg_merge else 1 ) + # wan2.2 s2v if audio_embeds is not None: - return model_fn_wans2v( - dit=dit, + return dit( latents=latents, timestep=timestep, - context=context, - audio_embeds=audio_embeds, + encoder_hidden_states=context, + audio_input=audio_embeds, + pose_cond=s2v_pose_latents, motion_latents=motion_latents, - s2v_pose_latents=s2v_pose_latents, - drop_motion_frames=drop_motion_frames, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, use_gradient_checkpointing=use_gradient_checkpointing, use_unified_sequence_parallel=use_unified_sequence_parallel, + tea_cache=tea_cache, ) if use_unified_sequence_parallel: @@ -1310,6 +1367,7 @@ def model_fn_wans2v( use_gradient_checkpointing_offload=False, use_gradient_checkpointing=False, use_unified_sequence_parallel=False, + tea_cache=None, ): 
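# A hedged sketch of the idea behind the TeaCache fields shown above
# (rel_l1_thresh, previous_residual, previous_hidden_states, coefficients): the
# relative L1 change of the timestep-modulated input is rescaled with a fitted
# polynomial and accumulated; while the accumulated value stays below the
# threshold, the transformer blocks are skipped and the cached residual is
# re-applied. This illustrates the caching strategy only and is not the
# pipeline's actual implementation.
import numpy as np
import torch

class TinyTeaCache:
    def __init__(self, rel_l1_thresh, coefficients):
        self.rel_l1_thresh = rel_l1_thresh
        self.poly = np.poly1d(coefficients)
        self.accum = 0.0
        self.prev_mod = None
        self.prev_hidden = None
        self.prev_residual = None

    def check(self, hidden_states, t_mod):
        if self.prev_mod is None:
            skip = False
        else:
            rel_l1 = ((t_mod - self.prev_mod).abs().mean() / self.prev_mod.abs().mean()).item()
            self.accum += float(self.poly(rel_l1))
            skip = self.accum < self.rel_l1_thresh
        if not skip:
            self.accum = 0.0
        self.prev_mod = t_mod.clone()
        self.prev_hidden = hidden_states.clone()
        return skip

    def update(self, hidden_states):
        return hidden_states + self.prev_residual  # reuse the cached block residual

    def store(self, hidden_states):
        self.prev_residual = hidden_states - self.prev_hidden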
if use_unified_sequence_parallel: import torch.distributed as dist @@ -1348,6 +1406,11 @@ def model_fn_wans2v( t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2) + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, x, t_mod) + else: + tea_cache_update = False + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}" @@ -1356,14 +1419,28 @@ def model_fn_wans2v( seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)] seq_len_x = seq_len_x_list[sp_rank] - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward + if tea_cache_update: + x = tea_cache.update(x) + else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward - for block_id, block in enumerate(dit.blocks): - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): + for block_id, block in enumerate(dit.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, seq_len_x, pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + elif use_gradient_checkpointing: x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, context, t_mod, seq_len_x, pre_compute_freqs[0], @@ -1374,20 +1451,12 @@ def custom_forward(*inputs): x, use_reentrant=False, ) - elif use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, seq_len_x, pre_compute_freqs[0], - use_reentrant=False, - ) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, - use_reentrant=False, - ) - else: - x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) - x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) + else: + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) + x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) + + if tea_cache is not None: + tea_cache.store(x) if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: x = get_sp_group().all_gather(x, dim=1) diff --git a/diffsynth/utils/__init__.py b/diffsynth/utils/__init__.py index ec3c7270..94603bc6 100644 --- a/diffsynth/utils/__init__.py +++ b/diffsynth/utils/__init__.py @@ -227,7 +227,7 @@ def __init__( input_params: tuple[str] = None, input_params_posi: dict[str, str] = None, input_params_nega: dict[str, str] = None, - onload_model_names: tuple[str] = None + onload_model_names: tuple[str] = None, ): self.seperate_cfg = seperate_cfg self.take_over = take_over @@ -247,6 +247,7 @@ def __init__(self): pass def __call__(self, 
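# A minimal sketch of how each sequence-parallel rank derives its local seq_len_x
# (how many of its tokens belong to the video segment rather than the appended
# reference/motion tokens), mirroring the seg_idxs / seq_len_x_list logic above.
import torch

def local_video_token_count(total_video_tokens, tokens_per_rank, world_size, rank):
    seg_idxs = [0] + list(torch.cumsum(torch.tensor([tokens_per_rank] * world_size), dim=0).numpy())
    counts = [min(max(0, total_video_tokens - seg_idxs[i]), tokens_per_rank) for i in range(world_size)]
    return int(counts[rank])

# e.g. 26 video tokens plus 6 reference tokens split over 4 ranks of 8 tokens each:
print([local_video_token_count(26, 8, 4, r) for r in range(4)])  # [8, 8, 8, 2]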
unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]: + if unit.take_over: # Let the pipeline unit take over this function. inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega) diff --git a/examples/wanvideo/model_inference/Wan2.2-S2V-14B.py b/examples/wanvideo/model_inference/Wan2.2-S2V-14B.py index bb93871a..426276f8 100644 --- a/examples/wanvideo/model_inference/Wan2.2-S2V-14B.py +++ b/examples/wanvideo/model_inference/Wan2.2-S2V-14B.py @@ -45,6 +45,8 @@ audio_sample_rate=sample_rate, input_audio=input_audio, num_inference_steps=40, + tea_cache_l1_thresh=0.05, + tea_cache_model_id="Wan2.1-I2V-14B-480P", ) save_video_with_audio(video[1:], "video_with_audio.mp4", audio_path, fps=16, quality=5)
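# A fragment showing how the new TeaCache arguments are passed at call time,
# following the updated examples/wanvideo/model_inference/Wan2.2-S2V-14B.py; it
# assumes `pipe`, `input_audio`, `sample_rate`, and the other inputs are prepared
# as in the rest of that example script.
video = pipe(
    # ... prompt / image / other arguments from the example script go here ...
    input_audio=input_audio,
    audio_sample_rate=sample_rate,
    num_inference_steps=40,
    tea_cache_l1_thresh=0.05,                   # skip blocks while the accumulated rel-L1 stays below this
    tea_cache_model_id="Wan2.1-I2V-14B-480P",   # selects the fitted rescaling coefficients
)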