add moe #2467

140 changes: 140 additions & 0 deletions paddleformers/nn/moe/allgather.py
@@ -0,0 +1,140 @@
import inspect
from typing import Callable, Dict, List, Optional, Tuple

import paddle
import paddle.distributed as dist
from paddle import framework, nn
from paddle.autograd import PyLayer
from paddle.distributed import fleet
from paddle.distributed.communication.group import Group, _get_global_group
from paddle.distributed.fleet.utils import recompute
from paddle.incubate.nn.functional import (
build_src_rank_and_local_expert_id,
expand_modality_expert_id,
moe_gate_dispatch_partial_nosoftmaxtopk,
)
from paddle.incubate.tensor.manipulation import async_offload
from paddleformers.peft.lora.lora_quantization_layers import QuantizationLoRALinear
from paddleformers.utils.log import logger

from paddleformers.transformers.ernie4_5.distributed.common_dist_utils import (
AllGatherGroupOp,
ReduceScatterGroupOp,
all_gather_group,
get_async_loader,
hack_offload_wait,
reduce_scatter_group,
)

from .utils import manual_backward

class AllGatherAsync(PyLayer):
"""
Perform async allgather.
"""

@staticmethod
def forward(ctx, input, *fn_args, group=None, fn=None, is_first_fwd=False):
"""Forward pass with integrated communication-computation overlap.

Args:
ctx: PyLayer context object
input (Tensor): Sharded input tensor [s/n, b, h]
*fn_args: Arguments for custom forward function
group: Model parallel process group
fn: Custom forward function to execute after communication
is_first_fwd: Flag indicating first forward pass in sequence

Returns:
tuple: (gathered_tensor, ...custom_forward_outputs)
"""
ctx.group = group
if dist.get_world_size(group) <= 1:
ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args)
return (input,) + fn_out
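        # Overlap: launch the all-gather on a communication stream, run fn (e.g. the
        # gate computation) on the calc stream, then wait for the gathered tensor.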
out, task = allgather_async(input, group=group)
ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args)
task and task.wait()
return (out,) + fn_out

@staticmethod
def backward(ctx, grad, *fn_out_grads):
"""Backward pass with gradient synchronization.

Args:
ctx: PyLayer context with stored communication group
grad (Tensor): Full gradient tensor [s, b, h]
*fn_out_grads: Gradients from custom forward outputs

Returns:
tuple: (scattered_grad, ...custom_arg_grads)
"""
if dist.get_world_size(ctx.group) <= 1:
fn_args_grads = ctx.bwf(*fn_out_grads)
return (grad,) + fn_args_grads

grad, task = reduce_scatter_async(grad, group=ctx.group)
fn_args_grads = ctx.bwf(*fn_out_grads)
task and task.wait()
return (grad,) + fn_args_grads

def allgather_async(input, group=None):
"""Perform asynchronous All-Gather operation for model parallelism.

Args:
input (Tensor): Local tensor to gather (shape: [N, ...])
group (ProcessGroup): Model parallel group (default: auto-detected)

Returns:
tuple: (output_tensor, communication_task)
output_tensor: Pre-allocated buffer with shape [N*K, ...] (K=group_size)
communication_task: Paddle communication task handle for synchronization
"""
if group is None:
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_model_parallel_group()
parallelism = group.nranks
if parallelism == 1:
return input.clone(), None
output_shape = input.shape
output_shape[0] = output_shape[0] * parallelism
output = paddle.empty(shape=output_shape, dtype=input.dtype)
task = dist.stream.all_gather(
output, input, group=group, use_calc_stream=False, sync_op=False
)
return output, task

def reduce_scatter_async(input, group=None):
"""Perform asynchronous reduce-scatter operation for distributed training.

Args:
input (Tensor): Local tensor to reduce (shape: [N*K, ...], N=group_size)
group (ProcessGroup): Communication group (default: model parallel group)

Returns:
tuple: (output_tensor, communication_task)
output_tensor: Scattered tensor portion with shape [K, ...]
communication_task: Handle for synchronizing the async operation
"""
if group is None:
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_model_parallel_group()
parallelism = group.nranks
if parallelism == 1:
return input.clone(), None
output_shape = input.shape
assert (
input.shape[0] % parallelism == 0
), f"Input sequence length {input.shape[0]} can't be divided exactly by sequence parallelism {parallelism}"
output_shape[0] = output_shape[0] // parallelism
output = paddle.empty(shape=output_shape, dtype=input.dtype)
task = dist.stream.reduce_scatter(
output,
input,
op=dist.ReduceOp.SUM,
group=group,
use_calc_stream=False,
sync_op=False,
)
return output, task
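
A usage sketch (reviewer note, not part of the diff): the helpers above follow a launch / overlap / wait pattern, and AllGatherAsync wraps it as a PyLayer whose backward issues the matching reduce-scatter. The fleet setup, tensor shapes and gate_fn below are illustrative assumptions, not APIs introduced by this PR.

# Sketch only: assumes fleet has been initialized with a model-parallel group
# and that the caller supplies the function to overlap (gate_fn here).
import paddle
from paddle.distributed import fleet

from paddleformers.nn.moe.allgather import (
    AllGatherAsync,
    allgather_async,
    reduce_scatter_async,
)

mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()

local = paddle.randn([8, 2, 512])      # [s/n, b, h] shard (illustrative shape)

# Low-level pattern: start the collective, overlap independent work, then wait.
gathered, task = allgather_async(local, group=mp_group)
# ... computation that does not depend on `gathered` goes here ...
if task is not None:
    task.wait()                        # gathered: [s, b, h]

scattered, task = reduce_scatter_async(gathered, group=mp_group)
if task is not None:
    task.wait()                        # scattered: [s/n, b, h]

def gate_fn(t):
    # Stand-in for the gate computation that runs while the gather is in flight.
    return paddle.nn.functional.softmax(t, axis=-1)

# PyLayer pattern: the gather overlaps gate_fn, and backward reduce-scatters the
# incoming gradient (manual_backward is assumed to return gate_fn's outputs as a tuple).
outputs = AllGatherAsync.apply(local, local, group=mp_group, fn=gate_fn)
# outputs[0] is the gathered tensor; the rest are gate_fn's outputs.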

125 changes: 125 additions & 0 deletions paddleformers/nn/moe/alltoall.py
@@ -0,0 +1,125 @@
import paddle
import paddle.distributed as dist
import paddle.nn.functional as F
from paddle import Tensor, _C_ops, framework, nn
from paddle.autograd import PyLayer
from paddle.distributed import fleet
from paddle.distributed.communication import stream
from paddle.distributed.communication.group import Group
from paddle.distributed.fleet.utils import recompute
from paddle.incubate.nn.functional import moe_combine, moe_gate_dispatch
from paddleformers.utils.log import logger
from paddleformers.transformers.ernie4_5.sequence_parallel_utils import ScatterOp

from .utils import manual_backward

class AlltoAll(PyLayer):
"""
    Custom PyLayer for All-to-All communication whose backward performs the reverse All-to-All.
"""
@staticmethod
def forward(ctx, x, group, sync_op=True):
"""
Perform All-to-All communication in the group.

Args:
x: Input tensor
group: Communication group
sync_op: Whether to perform synchronous operation

Returns:
Tensor: Output tensor
"""
ctx.group = group
if dist.get_world_size(group) <= 1:
return x
output = paddle.empty_like(x)
output.stop_gradient = False
task = stream.alltoall_single(
output, x, None, None, group, sync_op=sync_op, use_calc_stream=sync_op
)
if not sync_op:
return output, task
else:
return output

@staticmethod
def backward(ctx, *dx):
"""
Backward pass for All-to-All communication.

Args:
dx: Gradient tensor

Returns:
Tensor: Gradient after backward All-to-All
"""
return AlltoAll.apply(*dx, group=ctx.group)

class AlltoAllAsync(PyLayer):
"""
Custom PyLayer for asynchronous All-to-All communication.
"""
@staticmethod
def forward(ctx, x, *fn_args, group=None, fn=None, is_first_fwd=False):
"""
Asynchronous All-to-All communication with function execution.

Args:
x: Input tensor
fn_args: Arguments for the function
group: Communication group
fn: Function to execute
is_first_fwd: Whether this is the first forward pass

Returns:
tuple: (output tensor, function outputs)
"""
        assert fn is not None, "fn is required; use AlltoAll for a plain (non-async) all-to-all"
ctx.group = group
if dist.get_world_size(group) <= 1:
ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args)
return (x,) + fn_out
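        # Overlap: issue the all-to-all asynchronously on a communication stream,
        # run fn on the calc stream, then wait so the exchanged tensor is ready.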
x_out = paddle.empty_like(x)
x_out.stop_gradient = False
task = stream.alltoall_single(
x_out,
x,
None,
None,
group,
sync_op=False,
)
ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args)
task.wait()
return (x_out,) + fn_out

@staticmethod
def backward(ctx, dx_out, *fn_out_grads):
"""
Backward pass for asynchronous All-to-All.

Args:
dx_out: Gradient of output
fn_out_grads: Gradients of function outputs

Returns:
tuple: (gradient tensor, function argument gradients)
"""
if dist.get_world_size(ctx.group) <= 1:
fn_args_grads = ctx.bwf(*fn_out_grads)
return (dx_out,) + fn_args_grads

dx = paddle.empty_like(dx_out)
dx.stop_gradient = False
task = stream.alltoall_single(
dx,
dx_out,
None,
None,
ctx.group,
sync_op=False,
)
fn_args_grads = ctx.bwf(*fn_out_grads)
task.wait()
return (dx,) + fn_args_grads
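
A usage sketch (reviewer note, not part of the diff): AlltoAll is the blocking token exchange, while AlltoAllAsync overlaps the exchange with a caller-supplied fn and waits before returning; its backward runs the reverse exchange the same way. The process-group setup, shapes and gate_fn below are illustrative assumptions.

# Sketch only: assumes the distributed environment is already launched
# (e.g. via paddle.distributed.launch) and that gate_fn stands in for
# work that is independent of the exchanged tokens.
import paddle
import paddle.distributed as dist

from paddleformers.nn.moe.alltoall import AlltoAll, AlltoAllAsync

dist.init_parallel_env()
ep_group = dist.new_group(ranks=list(range(dist.get_world_size())))

x = paddle.randn([16, 512])            # routed tokens, [num_tokens, h] (illustrative)

# Synchronous exchange: blocks until the all-to-all completes.
y = AlltoAll.apply(x, ep_group)

def gate_fn(t):
    # Stand-in for computation overlapped with the communication.
    return paddle.nn.functional.softmax(t, axis=-1)

# Overlapped exchange: the all-to-all is issued on a communication stream while
# gate_fn runs on the calc stream; the PyLayer waits before returning.
outputs = AlltoAllAsync.apply(x, x, group=ep_group, fn=gate_fn)
# outputs[0] is the exchanged tensor; the rest are gate_fn's outputs.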