45 changes: 33 additions & 12 deletions megatron/core/tensor_parallel/mappings.py
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import torch

@@ -419,11 +419,12 @@ def backward(ctx, grad_output):

class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx, group, input, output_split_sizes, input_split_sizes):
def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_nccl_stream=False):
"""Forward function."""
ctx.group = group
ctx.output_split_sizes = output_split_sizes
ctx.input_split_sizes = input_split_sizes
ctx.use_nccl_stream = use_nccl_stream

world_size = group.size()
# Bypass the function if we are using only 1 GPU.
@@ -441,21 +442,39 @@ def forward(ctx, group, input, output_split_sizes, input_split_sizes):
dtype=input.dtype,
device=torch.cuda.current_device(),
)
torch.distributed.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
if use_nccl_stream:
handle = torch.distributed.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=True,
)
handle.wait()
else:
torch.distributed.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
return output

Contributor: What is the difference between the if and else paths?

Contributor Author: On Blackwell, we want to enable CUDA_DEVICE_MAX_CONNECTIONS > 1 and use stream priority for COMM/GEMM overlap (COMM issued at higher priority can preempt SMs used by lower-priority GEMM kernels). However, since the refactor in pytorch/pytorch#148590, PyTorch launches the NCCL kernel on the default stream, which makes the NCCL stream priority ineffective (the priority of the default stream cannot be changed). So we need to pass async_op=True to put the collective on a separate NCCL stream (see the usage sketch after this hunk).

@staticmethod
def backward(ctx, *grad_output):
"""Backward function."""
return (
None,
_AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes),
_AllToAll.apply(
ctx.group,
*grad_output,
ctx.input_split_sizes,
ctx.output_split_sizes,
ctx.use_nccl_stream,
),
None,
None,
None,
)
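A minimal usage sketch tied to the review discussion above (not part of this PR): the dispatch_tokens wrapper and its arguments are hypothetical; only all_to_all and its signature come from this file, and an initialized torch.distributed process group is assumed.

import torch
from megatron.core.tensor_parallel.mappings import all_to_all

def dispatch_tokens(tokens: torch.Tensor, ep_group, out_splits=None, in_splits=None):
    # With use_nccl_stream=True the autograd function issues all_to_all_single with
    # async_op=True and immediately waits on the handle, so the result matches the
    # blocking call, but the NCCL kernel is launched on NCCL's own communication
    # stream instead of the default stream, which keeps stream-priority-based
    # COMM/GEMM overlap effective.
    return all_to_all(ep_group, tokens, out_splits, in_splits, use_nccl_stream=True)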
@@ -532,10 +551,12 @@ def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group=None):
return _ReduceScatterToTensorParallelRegion.apply(input_, group)


def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None):
def all_to_all(
group, input_, output_split_sizes_=None, input_split_sizes=None, use_nccl_stream=False
):
"""Wrapper for autograd function"""
assert group is not None, "group should not be None"
return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes)
return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_nccl_stream)


def all_to_all_sp2hp(input_, group=None):
106 changes: 89 additions & 17 deletions megatron/core/transformer/moe/shared_experts.py
@@ -1,7 +1,9 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import warnings
from copy import deepcopy
from enum import Enum
from functools import wraps
from typing import Optional

import torch
@@ -27,6 +29,63 @@
)


class SharedExpertState(Enum):
"""State machine states for SharedExpertMLP overlapped forward pass."""

IDLE = 0
PRE_FORWARD_COMM_DONE = 1
FC1_FORWARD_DONE = 2
FC2_FORWARD_DONE = 3
POST_FORWARD_COMM_DONE = 4


def overlap_state_check(required_state: "SharedExpertState", next_state: "SharedExpertState"):
"""
Decorator to validate overlap state and cached variables before method execution,
and update state after method execution.

Args:
required_state: The expected SharedExpertState before this method runs.
next_state: The SharedExpertState to transition to after method execution.
"""

def decorator(method):
@wraps(method)
def wrapper(self, *args, **kwargs):
# Check overlap is enabled
assert (
self.config.moe_shared_expert_overlap
), f"{method.__name__} requires --moe-shared-expert-overlap to be set"
# Check state machine
assert self._overlap_state == required_state, (
f"{method.__name__} must be called from {required_state.name} state, "
f"but current state is {self._overlap_state.name}"
)
# Execute method
result = method(self, *args, **kwargs)
# Update state after method execution
self._overlap_state = next_state
return result

return wrapper

return decorator
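For orientation, here is the call order the state machine enforces. This is an illustrative sketch only: the shared_expert handle and the dispatched/combined tensors are hypothetical, and the real call sites live in the MoE token dispatcher, which is not part of this diff.

# IDLE -> PRE_FORWARD_COMM_DONE -> FC1_FORWARD_DONE -> FC2_FORWARD_DONE
#      -> POST_FORWARD_COMM_DONE -> IDLE
shared_expert.pre_forward_comm(hidden_states)          # IDLE -> PRE_FORWARD_COMM_DONE
# ... token dispatcher communication overlaps here on the default stream ...
shared_expert.linear_fc1_forward_and_act(dispatched)   # -> FC1_FORWARD_DONE
shared_expert.linear_fc2_forward(combined)             # -> FC2_FORWARD_DONE
shared_expert.post_forward_comm()                      # -> POST_FORWARD_COMM_DONE
output = shared_expert.get_output()                    # -> IDLE, ready for the next call

Calling these methods out of order now fails with an assertion that names the expected and actual states.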


class _BackwardStreamWait(torch.autograd.Function):
@staticmethod
def forward(ctx, input, stream):
"""forward"""
ctx.stream = stream
return input

@staticmethod
def backward(ctx, grad_output):
"""backward with stream wait"""
ctx.stream.wait_stream(torch.cuda.current_stream())
return grad_output, None


class SharedExpertMLP(MLP):
"""
MLP layer for Shared Experts.
@@ -117,8 +176,11 @@ def __init__(
self.cached_output = None
self.gate_score = None

if self.stream is None:
self.stream = torch.cuda.Stream()
# State machine to ensure correct calling order of overlapped forward methods
self._overlap_state = SharedExpertState.IDLE

if SharedExpertMLP.stream is None:
SharedExpertMLP.stream = torch.cuda.Stream()

def forward(self, hidden_states):
"""Forward function"""
@@ -145,15 +207,19 @@ def sharded_state_dict(
sharded_state_dict.update(sub_sd)
return sharded_state_dict

def pre_forward_comm(self, input):
def wait_current_stream(self):
"""Wait for the current stream to complete."""
self.stream.wait_stream(torch.cuda.current_stream())

@overlap_state_check(SharedExpertState.IDLE, SharedExpertState.PRE_FORWARD_COMM_DONE)
def pre_forward_comm(self, input, wait_current_stream=True):
"""
All Gather for SP before forward.
This function is used to overlap shared experts with the dispatcher.
It is only useful when --moe-shared-expert-overlap is set and may be changed.
"""
assert self.config.moe_shared_expert_overlap
assert self.cached_output is None
self.stream.wait_stream(torch.cuda.current_stream())
if wait_current_stream:
self.wait_current_stream()
with torch.cuda.stream(self.stream):
if self.use_shared_expert_gate:
logits = torch.nn.functional.linear(input, self.gate_weight)
@@ -166,16 +232,15 @@ def pre_forward_comm(self, input):
self.cached_fc1_input = copy_to_tensor_model_parallel_region(input)
set_tensor_grad_fn_sequence_sr(self.cached_fc1_input, torch.iinfo(torch.int).max)

@overlap_state_check(
SharedExpertState.PRE_FORWARD_COMM_DONE, SharedExpertState.FC1_FORWARD_DONE
)
def linear_fc1_forward_and_act(self, overlapped_comm_output=None):
"""
Do Linear FC1 and activation function forward.
This function is used to overlap shared experts with the dispatcher.
It is only useful when --moe-shared-expert-overlap is set and may be changed.
"""
assert self.config.moe_shared_expert_overlap
assert self.cached_fc1_input is not None
if overlapped_comm_output is not None:
set_tensor_grad_fn_sequence_sr(overlapped_comm_output, torch.iinfo(torch.int).max)
with torch.cuda.stream(self.stream):
# [s, b, 4 * h/p]
intermediate_parallel, bias_parallel = self.linear_fc1(self.cached_fc1_input)
@@ -216,30 +281,38 @@ def glu(x):
intermediate_parallel = self.activation_func(intermediate_parallel)

self.cached_fc2_input = intermediate_parallel
# Tensor sequence number is used to control the backward order.
# Decrease the sequence number of the expert output to make the comm launched first
# in the backward order.
if overlapped_comm_output is not None and overlapped_comm_output.grad_fn is not None:
target_sequence_nr = overlapped_comm_output.grad_fn._sequence_nr() - 1
set_tensor_grad_fn_sequence_sr(intermediate_parallel, target_sequence_nr)
# Make sure the shared expert fc1 backward is launched after the routed fc1 backward
self.cached_fc2_input = _BackwardStreamWait.apply(intermediate_parallel, self.stream)

@overlap_state_check(SharedExpertState.FC1_FORWARD_DONE, SharedExpertState.FC2_FORWARD_DONE)
def linear_fc2_forward(self, overlapped_comm_output=None):
"""
Do Linear FC2 forward.
This function is used to overlap shared experts with the dispatcher.
It is only useful when --moe-shared-expert-overlap is set and may be changed.
"""
assert self.config.moe_shared_expert_overlap
assert self.cached_fc2_input is not None
if overlapped_comm_output is not None:
set_tensor_grad_fn_sequence_sr(overlapped_comm_output, torch.iinfo(torch.int).max)
with torch.cuda.stream(self.stream):
# [s, b, h]
self.cached_fc2_output, _ = self.linear_fc2(self.cached_fc2_input)
self.cached_fc2_input = None

@overlap_state_check(
SharedExpertState.FC2_FORWARD_DONE, SharedExpertState.POST_FORWARD_COMM_DONE
)
def post_forward_comm(self):
"""
Reduce scatter for SP after forward.
This function is used to overlap shared experts with the dispatcher.
It is only useful when --moe-shared-expert-overlap is set and may be changed.
"""
assert self.config.moe_shared_expert_overlap
assert self.cached_fc2_output is not None
with torch.cuda.stream(self.stream):
if self.config.sequence_parallel:
self.cached_output = reduce_scatter_to_sequence_parallel_region(
@@ -252,14 +325,13 @@ def post_forward_comm(self):
self.cached_fc2_output = None
set_tensor_grad_fn_sequence_sr(self.cached_output, torch.iinfo(torch.int).max)

@overlap_state_check(SharedExpertState.POST_FORWARD_COMM_DONE, SharedExpertState.IDLE)
def get_output(self):
"""
Gets the module forward output.
This function is used to overlap shared experts with the dispatcher.
It is only useful when --moe-shared-expert-overlap is set and may be changed.
"""
assert self.config.moe_shared_expert_overlap
assert self.cached_output is not None
with torch.cuda.stream(self.stream):
if self.use_shared_expert_gate:
assert self.gate_score is not None