diff --git a/fastvideo/attention/layer.py b/fastvideo/attention/layer.py
index 7b086a3c6..0fba2644a 100644
--- a/fastvideo/attention/layer.py
+++ b/fastvideo/attention/layer.py
@@ -11,7 +11,7 @@
 from fastvideo.forward_context import ForwardContext, get_forward_context
 from fastvideo.platforms import AttentionBackendEnum
 from fastvideo.utils import get_compute_dtype
-
+from fastvideo.layers.rotary_embedding import _apply_rotary_emb
 
 class DistributedAttention(nn.Module):
     """Distributed attention layer.
@@ -64,6 +64,7 @@ def forward(
         replicated_q: torch.Tensor | None = None,
         replicated_k: torch.Tensor | None = None,
         replicated_v: torch.Tensor | None = None,
+        freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         """Forward pass for distributed attention.
 
@@ -97,6 +98,11 @@
         qkv = sequence_model_parallel_all_to_all_4D(qkv,
                                                     scatter_dim=2,
                                                     gather_dim=1)
+        if freqs_cis is not None:
+            cos, sin = freqs_cis
+            # apply to q and k
+            qkv[:batch_size*2] = _apply_rotary_emb(qkv[:batch_size*2], cos, sin, is_neox_style=False)
+
         # Apply backend-specific preprocess_qkv
         qkv = self.attn_impl.preprocess_qkv(qkv, ctx_attn_metadata)
 
@@ -147,6 +153,7 @@ def forward(
         replicated_k: torch.Tensor | None = None,
         replicated_v: torch.Tensor | None = None,
         gate_compress: torch.Tensor | None = None,
+        freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         """Forward pass for distributed attention.
 
@@ -172,7 +179,7 @@
         forward_context: ForwardContext = get_forward_context()
         ctx_attn_metadata = forward_context.attn_metadata
-
+        batch_size, seq_len, num_heads, head_dim = q.shape
         # Stack QKV
         qkvg = torch.cat([q, k, v, gate_compress], dim=0)  # [3, seq_len, num_heads, head_dim]
 
@@ -182,9 +189,14 @@
                                                      scatter_dim=2,
                                                      gather_dim=1)
 
-        qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata)
+        if freqs_cis is not None:
+            cos, sin = freqs_cis
+            qkvg[:batch_size*2] = _apply_rotary_emb(qkvg[:batch_size*2], cos, sin, is_neox_style=False)
 
+        qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata)
+        q, k, v, gate_compress = qkvg.chunk(4, dim=0)
+
         output = self.attn_impl.forward(
             q, k, v, gate_compress,
             ctx_attn_metadata)  # type: ignore[call-arg]
@@ -244,6 +256,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
+        freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
     ) -> torch.Tensor:
         """
         Apply local attention between query, key and value tensors.
@@ -262,6 +275,10 @@
         forward_context: ForwardContext = get_forward_context()
         ctx_attn_metadata = forward_context.attn_metadata
 
+        if freqs_cis is not None:
+            cos, sin = freqs_cis
+            q = _apply_rotary_emb(q, cos, sin, is_neox_style=False)
+            k = _apply_rotary_emb(k, cos, sin, is_neox_style=False)
         output = self.attn_impl.forward(q, k, v, ctx_attn_metadata)
 
         return output
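
Note: with this patch, rotary embeddings are applied inside the attention layers, after the all-to-all has reassembled the full sequence on each rank, instead of per-block in the DiTs. A minimal sketch of the slicing convention used above on the stacked tensor (sizes assumed for illustration):

    import torch

    batch_size, seq_len, num_heads, head_dim = 2, 32, 8, 64
    q = torch.randn(batch_size, seq_len, num_heads, head_dim)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # DistributedAttention stacks the projections along dim 0 ...
    qkv = torch.cat([q, k, v], dim=0)  # [3 * batch_size, seq_len, num_heads, head_dim]

    # ... so qkv[:2 * batch_size] addresses exactly the q and k rows that need RoPE,
    # while v (qkv[2 * batch_size:]) is left untouched.
    assert torch.equal(qkv[:2 * batch_size], torch.cat([q, k], dim=0))
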
diff --git a/fastvideo/distributed/communication_op.py b/fastvideo/distributed/communication_op.py
index c1cad53c4..676e0bc52 100644
--- a/fastvideo/distributed/communication_op.py
+++ b/fastvideo/distributed/communication_op.py
@@ -4,7 +4,10 @@
 import torch
 import torch.distributed
 
-from fastvideo.distributed.parallel_state import get_sp_group, get_tp_group
+from fastvideo.distributed.parallel_state import (get_sp_group,
+                                                  get_sp_parallel_rank,
+                                                  get_sp_world_size,
+                                                  get_tp_group)
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -30,3 +33,16 @@ def sequence_model_parallel_all_gather(input_: torch.Tensor,
                                        dim: int = -1) -> torch.Tensor:
     """All-gather the input tensor across model parallel group."""
     return get_sp_group().all_gather(input_, dim)
+
+def sequence_model_parallel_shard(input_: torch.Tensor,
+                                  dim: int = 1) -> torch.Tensor:
+    """Shard the input tensor along `dim` across the sequence parallel group."""
+    sp_rank = get_sp_parallel_rank()
+    sp_world_size = get_sp_world_size()
+    assert input_.shape[dim] % sp_world_size == 0, f"input tensor dim={dim} must be divisible by sp_world_size={sp_world_size}"
+    elements_per_rank = input_.shape[dim] // sp_world_size
+    # Move the sharding dim to the front, slice out this rank's chunk, then move it back.
+    input_ = input_.movedim(dim, 0)
+    input_ = input_[sp_rank * elements_per_rank:(sp_rank + 1) * elements_per_rank]
+    input_ = input_.movedim(0, dim)
+    return input_
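
Note: sequence_model_parallel_shard is the local inverse of sequence_model_parallel_all_gather along the same dim; the DiTs below shard right after their embedding layers and gather right before the output head. A single-process sketch of the round trip (group size and ranks are hard-coded stand-ins for the real SP group):

    import torch

    def shard(x: torch.Tensor, dim: int, sp_rank: int, sp_world_size: int) -> torch.Tensor:
        # Same slicing as sequence_model_parallel_shard, without the process group.
        assert x.shape[dim] % sp_world_size == 0
        per_rank = x.shape[dim] // sp_world_size
        return x.movedim(dim, 0)[sp_rank * per_rank:(sp_rank + 1) * per_rank].movedim(0, dim)

    x = torch.arange(2 * 8 * 4).reshape(2, 8, 4)  # [batch, seq, hidden]
    shards = [shard(x, dim=1, sp_rank=r, sp_world_size=2) for r in range(2)]
    # Gathering every rank's shard along the same dim reconstructs the input.
    assert torch.equal(torch.cat(shards, dim=1), x)
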
diff --git a/fastvideo/entrypoints/video_generator.py b/fastvideo/entrypoints/video_generator.py
index 2e6f63696..cb16e3bc6 100644
--- a/fastvideo/entrypoints/video_generator.py
+++ b/fastvideo/entrypoints/video_generator.py
@@ -232,27 +232,20 @@ def _generate_single_video(
         orig_latent_num_frames = sampling_param.num_frames // 17 * 3
 
         if orig_latent_num_frames % fastvideo_args.num_gpus != 0:
-            # Adjust latent frames to be divisible by number of GPUs
-            if sampling_param.num_frames_round_down:
-                # Ensure we have at least 1 batch per GPU
-                new_latent_num_frames = max(
-                    1, (orig_latent_num_frames // num_gpus)) * num_gpus
-            else:
-                new_latent_num_frames = math.ceil(
-                    orig_latent_num_frames / num_gpus) * num_gpus
+            if use_temporal_scaling_frames:
                 # Convert back to number of frames, ensuring num_frames-1 is a multiple of temporal_scale_factor
-                new_num_frames = (new_latent_num_frames
+                new_num_frames = (orig_latent_num_frames
                                   - 1) * temporal_scale_factor + 1
             else:  # stepvideo only
                 # Find the least common multiple of 3 and num_gpus
                 divisor = math.lcm(3, num_gpus)
                 # Round up to the nearest multiple of this LCM
-                new_latent_num_frames = (
-                    (new_latent_num_frames + divisor - 1) // divisor) * divisor
+                orig_latent_num_frames = (
+                    (orig_latent_num_frames + divisor - 1) // divisor) * divisor
                 # Convert back to actual frames using the StepVideo formula
-                new_num_frames = new_latent_num_frames // 3 * 17
+                new_num_frames = orig_latent_num_frames // 3 * 17
 
             logger.info(
                 "Adjusting number of frames from %s to %s based on number of GPUs (%s)",
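
Note: the frame-count adjustment now rounds the original latent frame count in place instead of deriving a separate new_latent_num_frames. A worked example of the StepVideo branch shown above (input values assumed): 85 requested frames on 2 GPUs.

    import math

    num_frames, num_gpus = 85, 2
    orig_latent_num_frames = num_frames // 17 * 3                        # 15 latent frames
    assert orig_latent_num_frames % num_gpus != 0                        # 15 is not divisible by 2
    divisor = math.lcm(3, num_gpus)                                      # 6
    orig_latent_num_frames = (
        (orig_latent_num_frames + divisor - 1) // divisor) * divisor     # rounded up to 18
    new_num_frames = orig_latent_num_frames // 3 * 17                    # back to 102 frames
    assert (orig_latent_num_frames, new_num_frames) == (18, 102)
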
diff --git a/fastvideo/layers/rotary_embedding.py b/fastvideo/layers/rotary_embedding.py
index 6abe90609..5c81ace2c 100644
--- a/fastvideo/layers/rotary_embedding.py
+++ b/fastvideo/layers/rotary_embedding.py
@@ -369,6 +369,7 @@ def get_rotary_pos_embed(
     theta_rescale_factor=1.0,
     interpolation_factor=1.0,
     shard_dim: int = 0,
+    do_sp_sharding: bool = False,
     dtype: torch.dtype = torch.float32,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
@@ -383,7 +384,7 @@ def get_rotary_pos_embed(
         theta_rescale_factor: Rescale factor for theta. Defaults to 1.0
         interpolation_factor: Factor to scale positions. Defaults to 1.0
         shard_dim: Which dimension to shard for sequence parallelism. Defaults to 0.
-
+        do_sp_sharding: Whether to shard the positional embeddings for sequence parallelism. Defaults to False.
     Returns:
         Tuple of (cos, sin) tensors for rotary embeddings
     """
@@ -399,9 +400,13 @@ def get_rotary_pos_embed(
     ) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
 
     # Get SP info
-    sp_group = get_sp_group()
-    sp_rank = sp_group.rank_in_group
-    sp_world_size = sp_group.world_size
+    if do_sp_sharding:
+        sp_group = get_sp_group()
+        sp_rank = sp_group.rank_in_group
+        sp_world_size = sp_group.world_size
+    else:
+        sp_rank = 0
+        sp_world_size = 1
 
     freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
         rope_dim_list,
diff --git a/fastvideo/models/dits/hunyuanvideo.py b/fastvideo/models/dits/hunyuanvideo.py
index da14322ba..f49728fbf 100644
--- a/fastvideo/models/dits/hunyuanvideo.py
+++ b/fastvideo/models/dits/hunyuanvideo.py
@@ -23,6 +23,7 @@
 from fastvideo.models.dits.base import CachableDiT
 from fastvideo.models.utils import modulate
 from fastvideo.platforms import AttentionBackendEnum
+from fastvideo.distributed.communication_op import sequence_model_parallel_shard, sequence_model_parallel_all_gather
 
 
 class HunyuanRMSNorm(nn.Module):
@@ -239,14 +240,7 @@ def forward(
         img_q = self.img_attn_q_norm(img_q).to(img_v)
         img_k = self.img_attn_k_norm(img_k).to(img_v)
 
-        # Apply rotary embeddings
-        cos, sin = freqs_cis
-        img_q, img_k = _apply_rotary_emb(
-            img_q, cos, sin,
-            is_neox_style=False), _apply_rotary_emb(img_k,
-                                                    cos,
-                                                    sin,
-                                                    is_neox_style=False)
+
         # Prepare text for attention using fused operation
         txt_attn_input = self.txt_attn_norm(txt, txt_attn_shift,
                                             txt_attn_scale)
@@ -265,7 +259,7 @@
         txt_k = self.txt_attn_k_norm(txt_k).to(txt_k.dtype)
 
         # Run distributed attention
-        img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v)
+        img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v, freqs_cis=freqs_cis)
         img_attn_out, _ = self.img_attn_proj(
             img_attn.view(batch_size, image_seq_len, -1))
         # Use fused operation for residual connection, normalization, and modulation
@@ -395,18 +389,11 @@
         img_q, txt_q = q[:, :-txt_len], q[:, -txt_len:]
         img_k, txt_k = k[:, :-txt_len], k[:, -txt_len:]
         img_v, txt_v = v[:, :-txt_len], v[:, -txt_len:]
-        # Apply rotary embeddings to image parts
-        cos, sin = freqs_cis
-        img_q, img_k = _apply_rotary_emb(
-            img_q, cos, sin,
-            is_neox_style=False), _apply_rotary_emb(img_k,
-                                                    cos,
-                                                    sin,
-                                                    is_neox_style=False)
+
         # Run distributed attention
         img_attn_output, txt_attn_output = self.attn(img_q, img_k, img_v, txt_q,
-                                                     txt_k, txt_v)
+                                                     txt_k, txt_v, freqs_cis=freqs_cis)
         attn_output = torch.cat((img_attn_output, txt_attn_output),
                                 dim=1).view(batch_size, seq_len, -1)
         # Process MLP activation
@@ -593,7 +580,7 @@ def forward(self,
         # Get rotary embeddings
         freqs_cos, freqs_sin = get_rotary_pos_embed(
-            (tt * get_sp_world_size(), th, tw), self.hidden_size,
+            (tt, th, tw), self.hidden_size,
             self.num_attention_heads, self.rope_dim_list, self.rope_theta)
         freqs_cos = freqs_cos.to(x.device)
         freqs_sin = freqs_sin.to(x.device)
@@ -608,6 +595,7 @@
             vec = vec + self.guidance_in(guidance)
         # Embed image and text
         img = self.img_in(img)
+        img = sequence_model_parallel_shard(img, dim=1)
         txt = self.txt_in(txt, t)
         txt_seq_len = txt.shape[1]
         img_seq_len = img.shape[1]
@@ -648,6 +636,7 @@
         self.maybe_cache_states(img, original_img)
 
         # Final layer processing
+        img = sequence_model_parallel_all_gather(img, dim=1)
         img = self.final_layer(img, vec)
         # Unpatchify to get original shape
         img = unpatchify(img, tt, th, tw, self.patch_size, self.out_channels)
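
Note: the HunyuanVideo forward pass now owns its sequence-parallel layout: tokens are sharded along the sequence dim right after img_in and gathered back just before final_layer, so the denoising stage no longer pre-shards latents. A toy, single-rank sketch of where the shapes change (collective ops replaced by local stand-ins, sizes assumed):

    import torch

    sp_world_size, sp_rank = 2, 0

    def shard(x, dim):  # stand-in for sequence_model_parallel_shard
        per = x.shape[dim] // sp_world_size
        return x.narrow(dim, sp_rank * per, per)

    def all_gather(x_local, dim):  # stand-in: the real op concatenates every rank's shard
        return torch.cat([x_local] * sp_world_size, dim=dim)

    img = torch.randn(1, 1024, 3072)   # [batch, seq, hidden] after img_in
    img = shard(img, dim=1)            # [1, 512, 3072] kept on this rank
    # ... transformer blocks run on the local shard; attention all-to-alls internally ...
    img = all_gather(img, dim=1)       # [1, 1024, 3072] again before final_layer
    assert img.shape == (1, 1024, 3072)
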
diff --git a/fastvideo/models/dits/wanvideo.py b/fastvideo/models/dits/wanvideo.py
index 2ca840622..b75d592fd 100644
--- a/fastvideo/models/dits/wanvideo.py
+++ b/fastvideo/models/dits/wanvideo.py
@@ -12,7 +12,7 @@
                                  LocalAttention)
 from fastvideo.configs.models.dits import WanVideoConfig
 from fastvideo.configs.sample.wan import WanTeaCacheParams
-from fastvideo.distributed.parallel_state import get_sp_world_size
+from fastvideo.distributed.communication_op import sequence_model_parallel_all_gather, sequence_model_parallel_shard
 from fastvideo.forward_context import get_forward_context
 from fastvideo.layers.layernorm import (FP32LayerNorm, LayerNormScaleShift,
                                         RMSNorm, ScaleResidual,
@@ -21,8 +21,7 @@
 # from torch.nn import RMSNorm
 # TODO: RMSNorm ....
 from fastvideo.layers.mlp import MLP
-from fastvideo.layers.rotary_embedding import (_apply_rotary_emb,
-                                               get_rotary_pos_embed)
+from fastvideo.layers.rotary_embedding import get_rotary_pos_embed
 from fastvideo.layers.visual_embedding import (ModulateProjection, PatchEmbed,
                                                TimestepEmbedder)
 from fastvideo.logger import init_logger
@@ -328,13 +327,7 @@
         key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
         value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
 
-        # Apply rotary embeddings
-        cos, sin = freqs_cis
-        query, key = _apply_rotary_emb(query, cos, sin,
-                                       is_neox_style=False), _apply_rotary_emb(
-                                           key, cos, sin, is_neox_style=False)
-
-        attn_output, _ = self.attn1(query, key, value)
+        attn_output, _ = self.attn1(query, key, value, freqs_cis=freqs_cis)
         attn_output = attn_output.flatten(2)
         attn_output, _ = self.to_out(attn_output)
         attn_output = attn_output.squeeze(1)
@@ -476,15 +469,11 @@
         gate_compress = gate_compress.squeeze(1).unflatten(
             2, (self.num_attention_heads, -1))
 
-        # Apply rotary embeddings
-        cos, sin = freqs_cis
-        query, key = _apply_rotary_emb(query, cos, sin,
-                                       is_neox_style=False), _apply_rotary_emb(
-                                           key, cos, sin, is_neox_style=False)
 
         attn_output, _ = self.attn1(query,
                                     key,
                                     value,
+                                    freqs_cis=freqs_cis,
                                     gate_compress=gate_compress)
         attn_output = attn_output.flatten(2)
         attn_output, _ = self.to_out(attn_output)
@@ -622,20 +611,20 @@ def forward(self,
         d = self.hidden_size // self.num_attention_heads
         rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
         freqs_cos, freqs_sin = get_rotary_pos_embed(
-            (post_patch_num_frames * get_sp_world_size(), post_patch_height,
+            (post_patch_num_frames, post_patch_height,
              post_patch_width),
             self.hidden_size,
             self.num_attention_heads,
             rope_dim_list,
             dtype=torch.float32
             if current_platform.is_mps() else torch.float64,
             rope_theta=10000)
-        freqs_cos = freqs_cos.to(hidden_states.device)
-        freqs_sin = freqs_sin.to(hidden_states.device)
-        freqs_cis = (freqs_cos.float(),
-                     freqs_sin.float()) if freqs_cos is not None else None
-
+        freqs_cis = (freqs_cos.to(hidden_states.device).float(),
+                     freqs_sin.to(hidden_states.device).float())
+
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
+        hidden_states = sequence_model_parallel_shard(hidden_states, dim=1)
 
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
             timestep, encoder_hidden_states, encoder_hidden_states_image)
@@ -681,15 +670,15 @@ def forward(self,
         shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states, shift, scale)
+        hidden_states = sequence_model_parallel_all_gather(hidden_states, dim=1)
         hidden_states = self.proj_out(hidden_states)
-
         hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames,
                                               post_patch_height,
                                               post_patch_width, p_t, p_h, p_w,
                                               -1)
         hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
         output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
-
+
         return output
 
     def maybe_cache_states(self, hidden_states: torch.Tensor,
diff --git a/fastvideo/models/vaes/common.py b/fastvideo/models/vaes/common.py
index 9730c8ff8..4218d8bbe 100644
--- a/fastvideo/models/vaes/common.py
+++ b/fastvideo/models/vaes/common.py
@@ -8,7 +8,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-from diffusers.utils.torch_utils import randn_tensor
+from fastvideo.utils import randn_tensor
 
 from fastvideo.configs.models import VAEConfig
 from fastvideo.distributed import get_sp_parallel_rank, get_sp_world_size
diff --git a/fastvideo/pipelines/pipeline_registry.py b/fastvideo/pipelines/pipeline_registry.py
index fcaa7edcb..1191a6e5d 100644
--- a/fastvideo/pipelines/pipeline_registry.py
+++ b/fastvideo/pipelines/pipeline_registry.py
@@ -21,6 +21,7 @@
     "WanPipeline": "wan",
     "WanDMDPipeline": "wan",
     "WanImageToVideoPipeline": "wan",
+    "WanDMDPipeline": "wan",
     "StepVideoPipeline": "stepvideo",
     "HunyuanVideoPipeline": "hunyuan",
 }
diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py
index c8fb76ce4..d46fbd2c0 100644
--- a/fastvideo/pipelines/stages/denoising.py
+++ b/fastvideo/pipelines/stages/denoising.py
@@ -15,10 +15,7 @@
 from fastvideo.attention import get_attn_backend
 from fastvideo.configs.pipelines.base import STA_Mode
-from fastvideo.distributed import (get_local_torch_device, get_sp_parallel_rank,
-                                   get_sp_world_size, get_world_group)
-from fastvideo.distributed.communication_op import (
-    sequence_model_parallel_all_gather)
+from fastvideo.distributed import (get_local_torch_device, get_world_group)
 from fastvideo.fastvideo_args import FastVideoArgs
 from fastvideo.forward_context import set_forward_context
 from fastvideo.logger import init_logger
@@ -31,6 +28,8 @@
 from fastvideo.pipelines.stages.validators import VerificationResult
 from fastvideo.platforms import AttentionBackendEnum
 from fastvideo.utils import dict_to_3d_list
+from fastvideo.utils import randn_tensor
+
 
 try:
     from fastvideo.attention.backends.sliding_tile_attn import (
@@ -118,22 +117,7 @@ def forward(
         autocast_enabled = (target_dtype != torch.float32
                             ) and not fastvideo_args.disable_autocast
 
-        # Handle sequence parallelism if enabled
-        sp_world_size, rank_in_sp_group = get_sp_world_size(
-        ), get_sp_parallel_rank()
-        sp_group = sp_world_size > 1
-        if sp_group:
-            latents = rearrange(batch.latents,
-                                "b c (n t) h w -> b c n t h w",
-                                n=sp_world_size).contiguous()
-            latents = latents[:, :, rank_in_sp_group, :, :, :]
-            batch.latents = latents
-            if batch.image_latent is not None:
-                image_latent = rearrange(batch.image_latent,
-                                         "b c (n t) h w -> b c n t h w",
-                                         n=sp_world_size).contiguous()
-                image_latent = image_latent[:, :, rank_in_sp_group, :, :, :]
-                batch.image_latent = image_latent
+
         # Get timesteps and calculate warmup steps
         timesteps = batch.timesteps
         # TODO(will): remove this once we add input/output validation for stages
@@ -339,10 +323,6 @@
                         and progress_bar is not None):
                     progress_bar.update()
 
-        # Gather results if using sequence parallelism
-        if sp_group:
-            latents = sequence_model_parallel_all_gather(latents, dim=2)
-
         # Update batch with final latents
         batch.latents = latents
 
@@ -679,11 +659,7 @@
         assert batch.latents is not None, "latents must be provided"
         latents = batch.latents
         # TODO(yongqi) hard code prepare latents
-        latents = torch.randn(
-            latents.permute(0, 2, 1, 3, 4).shape,
-            dtype=torch.bfloat16,
-            device="cuda",
-            generator=torch.Generator(device="cuda").manual_seed(42))
+        latents = latents.permute(0, 2, 1, 3, 4)
         video_raw_latent_shape = latents.shape
         prompt_embeds = batch.prompt_embeds
         assert torch.isnan(prompt_embeds[0]).sum() == 0
@@ -692,22 +668,6 @@
             dtype=torch.long,
             device=get_local_torch_device())
 
-        # Handle sequence parallelism if enabled
-        sp_world_size, rank_in_sp_group = get_sp_world_size(
-        ), get_sp_parallel_rank()
-        sp_group = sp_world_size > 1
-        if sp_group:
-            latents = rearrange(latents,
-                                "b (n t) c h w -> b n t c h w",
-                                n=sp_world_size).contiguous()
-            latents = latents[:, rank_in_sp_group, :, :, :, :]
-            if batch.image_latent is not None:
-                image_latent = rearrange(batch.image_latent,
-                                         "b c (n t) h w -> b c n t h w",
-                                         n=sp_world_size).contiguous()
-
-                image_latent = image_latent[:, :, rank_in_sp_group, :, :, :]
-                batch.image_latent = image_latent
 
         # Run denoising loop
         with self.progress_bar(total=len(timesteps)) as progress_bar:
@@ -799,14 +759,11 @@
                 if i < len(timesteps) - 1:
                     next_timestep = timesteps[i + 1] * torch.ones(
                         [1], dtype=torch.long, device=pred_video.device)
-                    noise = torch.randn(video_raw_latent_shape,
+                    noise = randn_tensor(video_raw_latent_shape,
                                         device=self.device,
-                                        dtype=pred_video.dtype)
-                    if sp_group:
-                        noise = rearrange(noise,
-                                          "b (n t) c h w -> b n t c h w",
-                                          n=sp_world_size).contiguous()
-                        noise = noise[:, rank_in_sp_group, :, :, :, :]
+                                        dtype=pred_video.dtype,
+                                        generator=batch.generator)
+
                     latents = self.scheduler.add_noise(
                         pred_video.flatten(0, 1), noise.flatten(0, 1),
                         next_timestep).unflatten(0, pred_video.shape[:2])
@@ -820,11 +777,8 @@
                         and progress_bar is not None):
                     progress_bar.update()
 
-        # Gather results if using sequence parallelism
-        if sp_group:
-            latents = sequence_model_parallel_all_gather(latents, dim=1)
         latents = latents.permute(0, 2, 1, 3, 4)
 
         # Update batch with final latents
         batch.latents = latents
-        return batch
+        return batch
\ No newline at end of file
diff --git a/fastvideo/training/distillation_pipeline.py b/fastvideo/training/distillation_pipeline.py
index 086c75288..575831c3a 100644
--- a/fastvideo/training/distillation_pipeline.py
+++ b/fastvideo/training/distillation_pipeline.py
@@ -206,11 +206,7 @@ def _generator_forward(self, training_batch: TrainingBatch) -> torch.Tensor:
         noise = torch.randn(self.video_latent_shape,
                             device=self.device,
                             dtype=dtype)
-        if self.sp_world_size > 1:
-            noise = rearrange(noise,
-                              "b (n t) c h w -> b n t c h w",
-                              n=self.sp_world_size).contiguous()
-            noise = noise[:, self.rank_in_sp_group, :, :, :, :]
+
         noisy_latent = self.noise_scheduler.add_noise(latents.flatten(0, 1),
                                                       noise.flatten(0, 1),
                                                       timestep).unflatten(
@@ -248,13 +244,6 @@
         current_noise_latents = torch.randn(self.video_latent_shape,
                                             device=self.device,
                                             dtype=dtype)
-        if self.sp_world_size > 1:
-            current_noise_latents = rearrange(
-                current_noise_latents,
-                "b (n t) c h w -> b n t c h w",
-                n=self.sp_world_size).contiguous()
-            current_noise_latents = current_noise_latents[:, self.
-                                                          rank_in_sp_group, :, :, :, :]
 
         # Only run intermediate steps if target_timestep_idx > 0
         max_target_idx = len(self.denoising_step_list) - 1
@@ -286,11 +275,7 @@
             noise = torch.randn(self.video_latent_shape,
                                 device=self.device,
                                 dtype=pred_clean.dtype)
-            if self.sp_world_size > 1:
-                noise = rearrange(noise,
-                                  "b (n t) c h w -> b n t c h w",
-                                  n=self.sp_world_size).contiguous()
-                noise = noise[:, self.rank_in_sp_group, :, :, :, :]
+
             current_noise_latents = self.noise_scheduler.add_noise(
                 pred_clean.flatten(0, 1), noise.flatten(0, 1),
                 next_timestep_tensor).unflatten(0, pred_clean.shape[:2])
@@ -334,11 +319,6 @@ def _dmd_forward(self, generator_pred_video: torch.Tensor,
         noise = torch.randn(self.video_latent_shape,
                             device=self.device,
                             dtype=generator_pred_video.dtype)
-        if self.sp_world_size > 1:
-            noise = rearrange(noise,
-                              "b (n t) c h w -> b n t c h w",
-                              n=self.sp_world_size).contiguous()
-            noise = noise[:, self.rank_in_sp_group, :, :, :, :]
 
         noisy_latent = self.noise_scheduler.add_noise(
             generator_pred_video.flatten(0, 1), noise.flatten(0, 1),
@@ -441,12 +421,6 @@ def faker_score_forward(
         fake_score_noise = torch.randn(self.video_latent_shape,
                                        device=self.device,
                                        dtype=generator_pred_video.dtype)
-        if self.sp_world_size > 1:
-            fake_score_noise = rearrange(fake_score_noise,
-                                         "b (n t) c h w -> b n t c h w",
-                                         n=self.sp_world_size).contiguous()
-            fake_score_noise = fake_score_noise[:, self.
-                                                rank_in_sp_group, :, :, :, :]
 
         noisy_generator_pred_video = self.noise_scheduler.add_noise(
             generator_pred_video.flatten(0, 1), fake_score_noise.flatten(0, 1),
@@ -515,13 +489,7 @@ def _prepare_dit_inputs(self,
         training_batch.latents = training_batch.latents.permute(0, 2, 1, 3, 4)
         self.video_latent_shape = training_batch.latents.shape
 
-        if self.sp_world_size > 1:
-            training_batch.latents = rearrange(
-                training_batch.latents,
-                "b (n t) c h w -> b n t c h w",
-                n=self.sp_world_size).contiguous()
-            training_batch.latents = training_batch.latents[:, self.
-                                                            rank_in_sp_group, :, :, :, :]
+
         self.video_latent_shape_sp = training_batch.latents.shape
diff --git a/fastvideo/training/training_pipeline.py b/fastvideo/training/training_pipeline.py
index 32e45c315..ce471b2f3 100644
--- a/fastvideo/training/training_pipeline.py
+++ b/fastvideo/training/training_pipeline.py
@@ -39,8 +39,7 @@
 from fastvideo.training.training_utils import (
     clip_grad_norm_while_handling_failing_dtensor_cases,
     compute_density_for_timestep_sampling, get_scheduler, get_sigmas,
-    load_checkpoint, normalize_dit_input, save_checkpoint,
-    shard_latents_across_sp)
+    load_checkpoint, normalize_dit_input, save_checkpoint)
 from fastvideo.utils import is_vsa_available, set_random_seed, shallow_asdict
 
 import wandb  # isort: skip
@@ -320,10 +319,12 @@
         # make sure no implicit broadcasting happens
         assert model_pred.shape == target.shape, f"model_pred.shape: {model_pred.shape}, target.shape: {target.shape}"
-        loss = (torch.mean((model_pred.float() - target.float())**2) /
-                self.training_args.gradient_accumulation_steps)
-
+
+        loss = torch.mean((model_pred.float() - target.float())**2)
+        loss /= self.training_args.gradient_accumulation_steps
         loss.backward()
+
+
         avg_loss = loss.detach().clone()
 
         # logger.info(f"rank: {self.rank}, avg_loss: {avg_loss.item()}",
@@ -331,7 +332,6 @@
         world_group = get_world_group()
         world_group.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
         training_batch.total_loss += avg_loss.item()
-
         return training_batch
 
     def _clip_grad_norm(self, training_batch: TrainingBatch) -> TrainingBatch:
@@ -367,17 +367,11 @@ def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch:
         training_batch = self._prepare_dit_inputs(training_batch)
 
         # Shard latents across sp groups
-        training_batch.latents = shard_latents_across_sp(
-            training_batch.latents,
-            num_latent_t=self.training_args.num_latent_t)
+        training_batch.latents = training_batch.latents[:, :, :self.training_args.num_latent_t]
         # shard noisy_model_input to match
-        training_batch.noisy_model_input = shard_latents_across_sp(
-            training_batch.noisy_model_input,
-            num_latent_t=self.training_args.num_latent_t)
+        training_batch.noisy_model_input = training_batch.noisy_model_input[:, :, :self.training_args.num_latent_t]
         # shard noise to match latents
-        training_batch.noise = shard_latents_across_sp(
-            training_batch.noise,
-            num_latent_t=self.training_args.num_latent_t)
+        training_batch.noise = training_batch.noise[:, :, :self.training_args.num_latent_t]
 
         training_batch = self._build_attention_metadata(training_batch)
         training_batch = self._build_input_kwargs(training_batch)
@@ -609,7 +603,7 @@ def _log_validation(self, transformer, training_args, global_step) -> None:
         validation_dataloader = DataLoader(validation_dataset,
                                            batch_size=None,
                                            num_workers=0)
-
+        return
         transformer.eval()
 
         validation_steps = training_args.validation_sampling_steps.split(",")
diff --git a/fastvideo/training/training_utils.py b/fastvideo/training/training_utils.py
index d0bf0961a..f3d0a8ada 100644
--- a/fastvideo/training/training_utils.py
+++ b/fastvideo/training/training_utils.py
@@ -535,17 +535,6 @@ def normalize_dit_input(model_type, latents, vae) -> torch.Tensor:
         raise NotImplementedError(f"model_type {model_type} not supported")
 
 
-def shard_latents_across_sp(latents: torch.Tensor,
-                            num_latent_t: int) -> torch.Tensor:
-    sp_world_size = get_sp_world_size()
-    rank_in_sp_group = get_sp_parallel_rank()
-    latents = latents[:, :, :num_latent_t]
-    if sp_world_size > 1:
-        latents = rearrange(latents,
-                            "b c (n s) h w -> b c n s h w",
-                            n=sp_world_size).contiguous()
-        latents = latents[:, :, rank_in_sp_group, :, :, :]
-    return latents
 
 
 def clip_grad_norm_while_handling_failing_dtensor_cases(
diff --git a/fastvideo/utils.py b/fastvideo/utils.py
index 9bbad98a4..dcc19962a 100644
--- a/fastvideo/utils.py
+++ b/fastvideo/utils.py
@@ -20,7 +20,7 @@
 from dataclasses import dataclass, fields, is_dataclass
 from functools import lru_cache, partial, wraps
 from typing import Any, TypeVar, cast
-
+from typing import Tuple, List, Optional, Union
 import cloudpickle
 import filelock
 import torch
@@ -812,3 +812,52 @@
 @lru_cache(maxsize=1)
 def is_vsa_available() -> bool:
     return importlib.util.find_spec("vsa") is not None
+
+
+# copied from https://github.com/huggingface/diffusers/blob/v0.19.2/src/diffusers/utils/torch_utils.py#L36
+def randn_tensor(
+    shape: Union[Tuple, List],
+    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+    device: Optional["torch.device"] = None,
+    dtype: Optional["torch.dtype"] = None,
+    layout: Optional["torch.layout"] = None,
+) -> torch.Tensor:
+    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
+    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
+    is always created on the CPU.
+    """
+    # device on which tensor is created defaults to device
+    rand_device = device
+    batch_size = shape[0]
+
+    layout = layout or torch.strided
+    device = device or torch.device("cpu")
+
+    if generator is not None:
+        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
+        if gen_device_type != device.type and gen_device_type == "cpu":
+            rand_device = "cpu"
+            if device != "mps":
+                logger.info(
+                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+                    f" slightly speed up this function by passing a generator that was created on the {device} device."
+                )
+        elif gen_device_type != device.type and gen_device_type == "cuda":
+            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
+
+    # make sure generator list of length 1 is treated like a non-list
+    if isinstance(generator, list) and len(generator) == 1:
+        generator = generator[0]
+
+    if isinstance(generator, list):
+        shape = (1,) + shape[1:]
+        latents = [
+            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
+            for i in range(batch_size)
+        ]
+        latents = torch.cat(latents, dim=0).to(device)
+    else:
+        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
+
+    return latents
\ No newline at end of file
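
Note: randn_tensor is vendored from diffusers so that seeded noise generation no longer pulls in diffusers.utils, and the denoising stage above now threads batch.generator through it. A usage sketch (shape and seed assumed for illustration):

    import torch
    from fastvideo.utils import randn_tensor  # the helper added above

    shape = (1, 16, 21, 60, 104)  # [batch, channels, latent_t, latent_h, latent_w]
    noise = randn_tensor(shape,
                         generator=torch.Generator(device="cpu").manual_seed(42),
                         device=torch.device("cpu"),
                         dtype=torch.float32)
    noise_again = randn_tensor(shape,
                               generator=torch.Generator(device="cpu").manual_seed(42),
                               device=torch.device("cpu"),
                               dtype=torch.float32)
    # Re-seeding the generator reproduces the exact same noise, which is what keeps
    # the multi-step denoising loop deterministic across runs.
    assert torch.equal(noise, noise_again)
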
diff --git a/scripts/inference/v1_inference_wan.sh b/scripts/inference/v1_inference_wan.sh
index 42d030164..2ffe8c271 100755
--- a/scripts/inference/v1_inference_wan.sh
+++ b/scripts/inference/v1_inference_wan.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-num_gpus=1
+num_gpus=2
 export FASTVIDEO_ATTENTION_BACKEND=
 export MODEL_BASE=Wan-AI/Wan2.1-T2V-1.3B-Diffusers
 # export MODEL_BASE=hunyuanvideo-community/HunyuanVideo
@@ -15,8 +15,8 @@ fastvideo generate \
     --text-encoder-cpu-offload True \
    --pin-cpu-memory False \
     --height 480 \
-    --width 832 \
-    --num-frames 77 \
+    --width 848 \
+    --num-frames 81 \
     --num-inference-steps 50 \
     --fps 16 \
     --guidance-scale 6.0 \