Description
Motivation
The current code forces users to shorten video length for sharding compatibility:
```python
def shard_latents_across_sp(latents: torch.Tensor,
                            num_latent_t: int) -> torch.Tensor:
    sp_world_size = get_sp_world_size()
    rank_in_sp_group = get_sp_parallel_rank()
    latents = latents[:, :, :num_latent_t]
    if sp_world_size > 1:
        latents = rearrange(latents,
                            "b c (n s) h w -> b c n s h w",
                            n=sp_world_size).contiguous()
        latents = latents[:, :, rank_in_sp_group, :, :, :]
    return latents
```
Problem: When num_latent_t is not divisible by sp_world_size (which happens often in WAN, whose temporal latent dimension is 21), the rearrange reshape fails.
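To illustrate the failure mode, here is a minimal reproduction using a plain NumPy reshape in place of einops' rearrange (the shapes and the divisibility requirement are the same); the tensor sizes are toy values, not the repo's defaults:

```python
import numpy as np

# Toy latents: (batch, channels, time, height, width), with t = 21 as in WAN.
latents = np.zeros((1, 4, 21, 8, 8))
sp_world_size = 2  # two sequence-parallel ranks

try:
    # Equivalent to rearrange "b c (n s) h w -> b c n s h w":
    # requires t % sp_world_size == 0, but 21 % 2 == 1.
    latents.reshape(1, 4, sp_world_size, 21 // sp_world_size, 8, 8)
except ValueError as e:
    print("reshape failed:", e)
```

With 21 frames and 2 ranks, the element counts no longer match (5376 vs. 5120), so the reshape raises, and users must truncate the video to a divisible length.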
Proposed Solution:
- Preserve the entire video segment by padding the temporal dimension to make it divisible by sp_world_size
- Return a validity mask alongside the sharded latents
- Account for the mask in both training loss computation and inference to ignore padded regions
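A minimal sketch of the proposed padded sharding, written with NumPy for self-containment (a torch version would use torch.cat / F.pad instead); shard_with_padding and its signature are hypothetical names for illustration, not the repo's API:

```python
import numpy as np

def shard_with_padding(latents, sp_world_size, rank):
    """Pad the temporal dim up to a multiple of sp_world_size, then shard.
    Returns (shard, mask), where mask is True at real (non-padded)
    temporal positions of this rank's shard."""
    b, c, t, h, w = latents.shape
    pad = (-t) % sp_world_size  # frames to append; 0 if already divisible
    if pad:
        latents = np.concatenate(
            [latents, np.zeros((b, c, pad, h, w), latents.dtype)], axis=2)
    mask = np.arange(t + pad) < t            # True for original frames
    s = (t + pad) // sp_world_size           # frames per rank
    shard = latents[:, :, rank * s:(rank + 1) * s]
    return shard, mask[rank * s:(rank + 1) * s]

# Example: t = 21, sp_world_size = 2 -> padded to 22, 11 frames per rank;
# rank 1 holds the single padded frame, so its mask has 10 True entries.
shard, mask = shard_with_padding(np.random.randn(1, 4, 21, 8, 8), 2, rank=1)
print(shard.shape, int(mask.sum()))  # (1, 4, 11, 8, 8) 10
```

The returned mask would then be broadcast against the per-element loss during training (and against decoded frames at inference) so padded positions contribute nothing.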