[Feature] Enable any frame length for SP parallel #676

@MartinPernus

Description

Motivation

The current code forces users to shorten video length for sharding compatibility:

import torch
from einops import rearrange


def shard_latents_across_sp(latents: torch.Tensor,
                            num_latent_t: int) -> torch.Tensor:
    """Shard latents [b, c, t, h, w] along the temporal dim across the SP group."""
    sp_world_size = get_sp_world_size()      # size of the sequence-parallel group
    rank_in_sp_group = get_sp_parallel_rank()
    latents = latents[:, :, :num_latent_t]   # truncate to the requested length
    if sp_world_size > 1:
        # split t into sp_world_size contiguous chunks; each rank keeps one chunk
        latents = rearrange(latents,
                            "b c (n s) h w -> b c n s h w",
                            n=sp_world_size).contiguous()
        latents = latents[:, :, rank_in_sp_group, :, :, :]
    return latents

Problem: when num_latent_t is not divisible by sp_world_size (common in WAN, whose latents have a temporal dimension of 21), the rearrange call fails because the temporal dimension cannot be split into sp_world_size equal chunks.
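A minimal repro of the failure (shapes are illustrative): the `(n s)` split in the snippet above is equivalent to the reshape below, which raises because 21 is not divisible by 2.

```python
import torch

latents = torch.zeros(1, 16, 21, 8, 8)  # b, c, t, h, w; channel count is illustrative
try:
    # equivalent of rearrange "b c (n s) h w -> b c n s h w" with n=2
    latents.reshape(1, 16, 2, 21 // 2, 8, 8)
except RuntimeError as err:
    print(f"sharding failed: {err}")
```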

Proposed Solution:

  • Preserve the entire video segment by padding the temporal dimension to make it divisible by sp_world_size
  • Return a validity mask alongside the sharded latents
  • Account for the mask in both training loss computation and inference to ignore padded regions
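A minimal sketch of the proposed pad-and-mask approach, using a plain reshape in place of einops and taking the group size and rank as explicit arguments instead of the repo's get_sp_world_size() / get_sp_parallel_rank() helpers (the function name is hypothetical):

```python
import torch


def shard_latents_with_mask(latents: torch.Tensor, num_latent_t: int,
                            sp_world_size: int, rank_in_sp_group: int):
    """Shard along the temporal dim, padding so any length divides sp_world_size."""
    latents = latents[:, :, :num_latent_t]
    b, c, t, h, w = latents.shape
    pad = (-t) % sp_world_size              # extra frames needed for divisibility
    if pad:
        # zero-pad at the end of the temporal dimension (dim 2)
        latents = torch.nn.functional.pad(latents, (0, 0, 0, 0, 0, pad))
    mask = torch.zeros(t + pad, dtype=torch.bool)
    mask[:t] = True                         # True = real frame, False = padding
    if sp_world_size > 1:
        s = (t + pad) // sp_world_size
        # same contiguous-chunk split as the (n s) rearrange above
        latents = latents.reshape(b, c, sp_world_size, s, h, w)
        latents = latents[:, :, rank_in_sp_group].contiguous()
        mask = mask.reshape(sp_world_size, s)[rank_in_sp_group]
    return latents, mask
```

Each rank returns its latent shard plus a per-frame validity mask; the loss and inference code would then multiply (or index) by the mask so padded frames contribute nothing.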
