16 changes: 8 additions & 8 deletions examples/multimodal/train.py
@@ -20,7 +20,7 @@
from megatron.core.enums import ModelType
from megatron.core.models.multimodal import context_parallel
from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, LLaVAModel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.context_parallel import DefaultContextParallelHandler
from megatron.core.parallel_state import (
get_tensor_model_parallel_rank,
get_pipeline_model_parallel_world_size,
@@ -42,15 +42,15 @@ def get_batch(data_iterator, image_token_index, img_seq_len):
attention_mask = None
position_ids = None
num_tiles = None
packed_seq_params = None
cp_handler = None

args = get_args()

# Dataloader doesn't run on the middle stages in a pipeline parallel model.
pp_size = get_pipeline_model_parallel_world_size()
if not is_first_or_last_stage(pp_size):
# Note these are all set to None above.
return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles, packed_seq_params
return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles, cp_handler

# Broadcast data.
torch.cuda.nvtx.range_push("get_data")
@@ -94,7 +94,7 @@ def get_batch(data_iterator, image_token_index, img_seq_len):
cu_lengths = cu_lengths[0]
max_lengths = max_lengths[0]

packed_seq_params = PackedSeqParams(
cp_handler = DefaultContextParallelHandler(
qkv_format="thd",
cu_seqlens_q=cu_lengths,
cu_seqlens_kv=cu_lengths,
@@ -135,7 +135,7 @@ def get_batch(data_iterator, image_token_index, img_seq_len):
tokens, position_ids, labels, loss_mask = [torch.nn.functional.pad(item, (0, mp_padding_needed)) for item in (tokens, position_ids, labels, loss_mask)]

    # Get a context parallel handler that indicates the amount of padding for TransformerEngine.
packed_seq_params = context_parallel.get_packed_seq_params(tokens, num_image_embeddings, mp_padding_needed, args.context_parallel_size, True)
cp_handler = context_parallel.get_cp_handler(tokens, num_image_embeddings, mp_padding_needed, args.context_parallel_size, True)

return (
tokens,
@@ -145,7 +145,7 @@ def get_batch(data_iterator, image_token_index, img_seq_len):
position_ids,
imgs,
num_tiles,
packed_seq_params,
cp_handler,
)


@@ -274,7 +274,7 @@ def forward_step(data_iterator, model: LLaVAModel):
position_ids,
images,
num_image_tiles,
packed_seq_params,
cp_handler,
) = get_batch(data_iterator, model.module.module.image_token_index, model.module.module.img_seq_len)
timers('batch-generator').stop()

@@ -286,7 +286,7 @@ def forward_step(data_iterator, model: LLaVAModel):
labels,
loss_mask,
num_image_tiles=num_image_tiles,
packed_seq_params=packed_seq_params,
cp_handler=cp_handler,
)
args = get_args()
if args.use_loss_scaling:
35 changes: 35 additions & 0 deletions megatron/core/context_parallel/__init__.py
@@ -0,0 +1,35 @@
from typing import Literal, Optional

from .backend import (
ContextParallelHandler,
DefaultContextParallelHandler,
MagiAttnContextParallelHandler,
TEDynamicContextParallelHandler,
)


def get_cp_handler_cls(
    backend: Optional[Literal["transformer_engine", "transformer_engine_dynamic", "local", "magi"]] = None,
cp_comm_type: Optional[str] = None,
) -> type[ContextParallelHandler]:
"""
Factory function to get the appropriate Context Parallel Handler class based on the backend.

Args:
        backend: The attention backend to use ('transformer_engine', 'transformer_engine_dynamic', 'local', or 'magi').
cp_comm_type: Optional communication type identifier (unused in current logic).

Returns:
The class definition of the appropriate ContextParallelHandler.

Raises:
ValueError: If an unsupported backend is provided.
"""
if backend == "transformer_engine" or backend == "local":
return DefaultContextParallelHandler
elif backend == "magi":
return MagiAttnContextParallelHandler
elif backend == "transformer_engine_dynamic":
return TEDynamicContextParallelHandler
else:
raise ValueError(f"Unsupported attention backend for context parallel: {backend}")
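
For orientation, a minimal usage sketch of this factory follows; it is not part of the diff. It assumes the handler subclasses keep the dataclass constructor of ContextParallelHandler (defined in _base.py below), so the legacy PackedSeqParams-style fields can be passed as keyword arguments; the sequence lengths are placeholder values.

# Minimal sketch, not code from this PR. Assumes DefaultContextParallelHandler
# accepts the ContextParallelHandler dataclass fields as keyword arguments,
# as the train.py hunk above suggests.
import torch

from megatron.core.context_parallel import get_cp_handler_cls

# Cumulative sequence lengths for two packed sequences of lengths 512 and 256.
cu_lengths = torch.tensor([0, 512, 768], dtype=torch.int32)

handler_cls = get_cp_handler_cls(backend="transformer_engine")  # resolves to DefaultContextParallelHandler
cp_handler = handler_cls(
    qkv_format="thd",
    cu_seqlens_q=cu_lengths,
    cu_seqlens_kv=cu_lengths,
    max_seqlen_q=512,
    max_seqlen_kv=512,
)
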
10 changes: 10 additions & 0 deletions megatron/core/context_parallel/backend/__init__.py
@@ -0,0 +1,10 @@
from ._base import ContextParallelHandler
from .default import DefaultContextParallelHandler, TEDynamicContextParallelHandler
from .magi import MagiAttnContextParallelHandler

__all__ = [
"ContextParallelHandler",
"DefaultContextParallelHandler",
"MagiAttnContextParallelHandler",
"TEDynamicContextParallelHandler",
]
101 changes: 101 additions & 0 deletions megatron/core/context_parallel/backend/_base.py
@@ -0,0 +1,101 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple

import torch
import torch.distributed as dist
from torch import Tensor, nn

from megatron.core import parallel_state

if TYPE_CHECKING:
from megatron.core.transformer import TransformerConfig


@dataclass
class ContextParallelHandler(ABC):
"""
Abstract base class for Context Parallel (CP) handlers.
Manages distribution, combination, and manipulation of tensors across context parallel ranks.
"""

    # Legacy fields carried over from PackedSeqParams
qkv_format: Optional[Literal["sbhd", "bshd", "thd"]] = None
cp_group: Optional[dist.ProcessGroup] = None

cu_seqlens_q: Optional[Tensor] = None
cu_seqlens_kv: Optional[Tensor] = None
cu_seqlens_q_padded: Optional[Tensor] = None
cu_seqlens_kv_padded: Optional[Tensor] = None
max_seqlen_q: Optional[int] = None
max_seqlen_kv: Optional[int] = None

    # Used in dcp
local_cp_size: Optional[int] = None

    # Used in DefaultContextParallelHandler
seqlens_q_list: Optional[List[int]] = None
seqlens_kv_list: Optional[List[int]] = None
seqlens_q_padded: Optional[torch.Tensor] = None
seqlens_kv_padded: Optional[torch.Tensor] = None
# Lists containing flattened [actual_len, padded_len] pairs
seqlens_q_with_padded_list: Optional[List[int]] = None
seqlens_kv_with_padded_list: Optional[List[int]] = None
total_seqlen_padded_q: Optional[int] = None
total_seqlen_padded_kv: Optional[int] = None

_post_initialized: bool = False
_cp_size: int = 1

def __post_init__(self):
if self.qkv_format is None:
self.qkv_format = "sbhd"

if self.cp_group is None:
self.cp_group = parallel_state.get_context_parallel_group(check_initialized=False)

if self.cp_group is not None:
self._cp_size = dist.get_world_size(self.cp_group)

@abstractmethod
def dispatch(
self, seq_dim: int, tensor: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
"""Splits and dispatches the tensor to the appropriate CP rank during forward pass."""
pass

@abstractmethod
def combine(
self, seq_dim: int, tensor: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
"""Combines tensors from different CP ranks (gather) during forward pass."""
pass

@abstractmethod
def apply_rotary_pos_emb(
self,
tensor: torch.Tensor,
freq: torch.Tensor,
config: "TransformerConfig",
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
"""Applies Rotary Positional Embeddings considering context parallelism."""
pass

@abstractmethod
def get_emb_on_this_cp_rank(self, emb: torch.Tensor) -> torch.Tensor:
"""Retrieves the slice of embeddings belonging to the current CP rank."""
pass

@abstractmethod
def roll_tensor(
self, tensor: torch.Tensor, shifts: int = -1, dims: int = -1
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Rolls the tensor elements along a dimension, handling communication across CP ranks."""
pass

@abstractmethod
def core_attn(self, attn_mod: nn.Module, *args: Any, **kwargs: Any) -> Any:
"""Executes the core attention logic using this handler."""
pass
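
To make the interface concrete, the sketch below shows one plausible way a transformer layer could thread tensors through a handler, based only on the method docstrings above; it is not part of the diff. Every name other than the handler methods (run_attention_with_cp, attention, hidden_states, rotary_freqs) is a hypothetical placeholder, and the real Megatron attention path wires these calls differently.

# Illustrative call pattern only; not code from this PR.
import torch
from torch import nn


def run_attention_with_cp(
    cp_handler,                   # any ContextParallelHandler subclass
    attention: nn.Module,         # backend core-attention module (placeholder)
    hidden_states: torch.Tensor,  # full-sequence activations (placeholder)
    rotary_freqs: torch.Tensor,   # precomputed RoPE frequencies (placeholder)
    config,                       # TransformerConfig for RoPE application
):
    # Keep only this CP rank's slice of the sequence dimension.
    local_chunk = cp_handler.dispatch(seq_dim=0, tensor=hidden_states)

    # Apply rotary embeddings with CP-aware frequency handling.
    local_chunk = cp_handler.apply_rotary_pos_emb(local_chunk, rotary_freqs, config)

    # Run the backend-specific core attention through the handler.
    attn_out = cp_handler.core_attn(attention, local_chunk)

    # Gather the full sequence back for consumers that need the combined output.
    return cp_handler.combine(seq_dim=0, tensor=attn_out)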