
Context Parallel support #381

Merged
a-r-r-o-w merged 38 commits into main from context-parallel on May 13, 2025
Conversation

a-r-r-o-w commented Apr 22, 2025

Context Parallel

References and reading material:

There are three steps to enabling context parallelism with any model:

  • Defining the context parallel plan: a dictionary specifying which tensors to split and gather across the CP region, at different layers in the model
  • Applying the CP plan with the apply_context_parallel function: this registers the hooks needed to split and gather tensors at the right places, without manually modifying the model code
  • Running the model under the attention_provider context manager

For a quick example, refer to the inference example below.
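The split/gather mechanism of step 2 can be sketched with plain PyTorch hooks. This is a single-process toy, not the finetrainers implementation: `world_size` is simulated, there is no process group, and the Linear layer and hook names are illustrative only.

```python
import torch
import torch.nn as nn

world_size = 2  # simulated CP degree; no real process group here


def make_split_hook(rank: int, dim: int):
    # Pre-forward hook: replace the module input with this rank's shard along `dim`.
    def pre_hook(module, args):
        shard = args[0].chunk(world_size, dim=dim)[rank]
        return (shard,) + args[1:]

    return pre_hook


linear = nn.Linear(8, 8)
x = torch.randn(1, 4, 8)  # [batch, sequence, features]

# Each simulated rank runs the module on its own sequence shard...
outs = []
for rank in range(world_size):
    handle = linear.register_forward_pre_hook(make_split_hook(rank, dim=1))
    outs.append(linear(x))
    handle.remove()

# ...and a post-forward "gather" concatenates the shards back along the sequence.
gathered = torch.cat(outs, dim=1)
```

Because the layer acts position-wise, running it shard-by-shard and concatenating matches running it on the full input; this is what the registered hooks arrange transparently.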

The CP plan is a dictionary keyed by the names of internal modules in the model. Each value maps a parameter identifier (a ParamId giving the keyword-argument name and/or positional index, as used in the module's forward method) to a CPInput or CPOutput object; for module outputs, a plain list of CPOutput objects can also be given. A CPInput specifies an input tensor to split, and a CPOutput specifies an output tensor to gather.

class ParamId:
    name: Optional[str] = None
    index: Optional[int] = None

class CPInput:
    split_dim: int
    expected_dims: Optional[int] = None
    split_output: bool = False

class CPOutput:
    gather_dim: int
    expected_dims: Optional[int] = None
  • The split_dim and gather_dim parameters specify the dimension along which to split or gather the tensor. With PyTorch's native scaled dot product attention, the tensor shape is [B, N, S, D], so split_dim and gather_dim should be set to 2, the sequence dimension.

  • The optional expected_dims parameter sanity-checks that the tensor has the expected number of dimensions.

  • By default, CPInput tensors are split in a pre-forward hook and CPOutput tensors are gathered in a post-forward hook. To split the output of a module instead, set split_output=True; the tensor is then split in the post-forward hook rather than the pre-forward hook.

  • Attention providers supported with CP, for both training and inference: flash, _native_cudnn, _native_efficient, _native_flash
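As a concrete check of the dim-2 convention, splitting SDPA queries along the sequence dimension and gathering the partial outputs along the same dimension reproduces the full result. This is a single-process sketch with assumed shapes; real context parallelism also rotates or gathers K/V across ranks, which is elided here.

```python
import torch
import torch.nn.functional as F

# SDPA tensors are laid out [B, N, S, D]; dim 2 is the sequence axis,
# which is what CPInput/CPOutput split and gather.
B, N, S, D = 1, 2, 8, 16
world_size = 2
q = torch.randn(B, N, S, D)
k = torch.randn(B, N, S, D)
v = torch.randn(B, N, S, D)

full = F.scaled_dot_product_attention(q, k, v)

# Each "rank" attends with only its query shard (K/V kept whole here);
# each output row depends only on the matching query row, so concatenating
# the shards along dim 2 reconstructs the full output.
shards = [
    F.scaled_dot_product_attention(q_shard, k, v)
    for q_shard in q.chunk(world_size, dim=2)
]
gathered = torch.cat(shards, dim=2)
print(torch.allclose(gathered, full, atol=1e-5))  # True
```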

Training

To enable training with context parallelism, make sure a suitable CP plan is registered for the model you are using and launch training with --cp_degree N, where N > 1. For models supported in finetrainers, this is handled internally in the transformer metadata file. For custom models, pass the plan argument to the apply_context_parallel function.

Currently supported models include: CogVideoX, CogView4, Flux, Wan 2.1. Support for more models and attention providers is in progress.

Inference

The following example shows how to run context parallel inference. For more examples and ready-to-use inference scripts, check out the examples/inference folder.

Example
import torch
import torch.distributed as dist
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

from finetrainers._metadata import ParamId, CPInput, CPOutput
from finetrainers.parallel.ptd import apply_context_parallel
from finetrainers.models.attention_dispatch import attention_provider, attention_dispatch

# Route every call to PyTorch's scaled_dot_product_attention through finetrainers'
# dispatcher, so the attention_provider context below selects the backend.
torch.nn.functional.scaled_dot_product_attention = attention_dispatch


def apply_compile(model: torch.nn.Module, compile_scope: str) -> torch.nn.Module:
    r"""Apply torch.compile to a model or its submodules if not already compiled."""
    if getattr(model, "_torch_compiled", False):
        return model  # Already compiled

    if compile_scope == "full":
        model = torch.compile(model)
        setattr(model, "_torch_compiled", True)
    elif compile_scope == "regional":
        # Compile each block of a ModuleList individually; otherwise recurse into children.
        if isinstance(model, torch.nn.ModuleList):
            for name, module in model.named_children():
                if not getattr(module, "_torch_compiled", False):
                    compiled_module = torch.compile(module, mode="max-autotune-no-cudagraphs", fullgraph=False, dynamic=False)
                    setattr(compiled_module, "_torch_compiled", True)
                    model.register_module(name, compiled_module)
        else:
            for name, module in model.named_children():
                apply_compile(module, compile_scope)
    else:
        raise ValueError(f"Unknown compile mode: {compile_scope}. Use 'full' or 'regional'.")

    return model


torch.manual_seed(0)
dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)
cp_mesh = dist.device_mesh.init_device_mesh("cuda", [world_size], mesh_dim_names=["cp"])

cp_plan = {
    "rope": {
        ParamId(index=0): CPInput(2, 4, split_output=True),
    },
    "blocks.*": {
        ParamId("encoder_hidden_states", 1): CPInput(1, 3),
    },
    "blocks.0": {
        ParamId("hidden_states", 0): CPInput(1, 3),
    },
    "proj_out": [CPOutput(1, 3)],
}
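The "blocks.*" key suggests plan keys are matched against module names with glob-style patterns. Assuming that, the matching can be sketched with the standard library; this is an illustration with hypothetical module names, not finetrainers' actual matcher.

```python
import fnmatch

# Hypothetical module names from the Wan transformer, matched against plan keys.
plan_keys = ["rope", "blocks.*", "blocks.0", "proj_out"]
module_names = ["rope", "blocks.0", "blocks.1", "blocks.29", "proj_out"]

matches = {
    key: [name for name in module_names if fnmatch.fnmatchcase(name, key)]
    for key in plan_keys
}
print(matches["blocks.*"])  # ['blocks.0', 'blocks.1', 'blocks.29']
```

Note that a wildcard entry such as "blocks.*" applies to every block, while an exact key such as "blocks.0" targets a single module; in the plan above this is used to split hidden_states only once, at the first block.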

try:
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    apply_context_parallel(pipe.transformer, mesh=cp_mesh, plan=cp_plan)

    apply_compile(pipe.transformer, compile_scope="regional")

    prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt, negative_prompt=negative_prompt, device="cuda",
        )
    
    attention_backend = "_native_flash"
    generator = torch.Generator().manual_seed(0)
    
    # Warmup for compilation
    with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="alltoall"):
        latents = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=480,
            width=832,
            num_frames=81,
            num_inference_steps=2,
            guidance_scale=5.0,
            output_type="latent",
            generator=generator,
        ).frames[0]

    # Inference
    with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="allgather"):
        latents = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=480,
            width=832,
            num_frames=81,
            guidance_scale=5.0,
            num_inference_steps=30,
            output_type="latent",
            generator=generator,
        ).frames[0]
    
    with torch.no_grad():
        latents = latents.to(pipe.vae.dtype)
        latents_mean = (
            torch.tensor(pipe.vae.config.latents_mean)
            .view(1, pipe.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        # latents_std stores the reciprocal of the true std, so dividing by it
        # below multiplies by the std; together with latents_mean this undoes
        # the VAE latent normalization before decoding.
        latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
        video = pipe.vae.decode(latents, return_dict=False)[0]
        video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    
    if rank == 0:
        export_to_video(video, "output.mp4", fps=16)
finally:
    dist.destroy_process_group()

Benchmarks

TODO: Will be updated in a future PR

@a-r-r-o-w a-r-r-o-w merged commit 2494f41 into main May 13, 2025
1 check passed
@a-r-r-o-w a-r-r-o-w deleted the context-parallel branch May 13, 2025 15:36