diff --git a/paddleformers/nn/moe/allgather.py b/paddleformers/nn/moe/allgather.py new file mode 100644 index 00000000000..2361c3fdbc6 --- /dev/null +++ b/paddleformers/nn/moe/allgather.py @@ -0,0 +1,140 @@ +import inspect +from typing import Callable, Dict, List, Optional, Tuple + +import paddle +import paddle.distributed as dist +from paddle import framework, nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.communication.group import Group, _get_global_group +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.nn.functional import ( + build_src_rank_and_local_expert_id, + expand_modality_expert_id, + moe_gate_dispatch_partial_nosoftmaxtopk, +) +from paddle.incubate.tensor.manipulation import async_offload +from paddleformers.peft.lora.lora_quantization_layers import QuantizationLoRALinear +from paddleformers.utils.log import logger + +from paddleformers.transformers.ernie4_5.distributed.common_dist_utils import ( + AllGatherGroupOp, + ReduceScatterGroupOp, + all_gather_group, + get_async_loader, + hack_offload_wait, + reduce_scatter_group, +) + +from .utils import manual_backward + +class AllGatherAsync(PyLayer): + """ + Perform async allgather. + """ + + @staticmethod + def forward(ctx, input, *fn_args, group=None, fn=None, is_first_fwd=False): + """Forward pass with integrated communication-computation overlap. + + Args: + ctx: PyLayer context object + input (Tensor): Sharded input tensor [s/n, b, h] + *fn_args: Arguments for custom forward function + group: Model parallel process group + fn: Custom forward function to execute after communication + is_first_fwd: Flag indicating first forward pass in sequence + + Returns: + tuple: (gathered_tensor, ...custom_forward_outputs) + """ + ctx.group = group + if dist.get_world_size(group) <= 1: + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + return (input,) + fn_out + out, task = allgather_async(input, group=group) + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + task and task.wait() + return (out,) + fn_out + + @staticmethod + def backward(ctx, grad, *fn_out_grads): + """Backward pass with gradient synchronization. + + Args: + ctx: PyLayer context with stored communication group + grad (Tensor): Full gradient tensor [s, b, h] + *fn_out_grads: Gradients from custom forward outputs + + Returns: + tuple: (scattered_grad, ...custom_arg_grads) + """ + if dist.get_world_size(ctx.group) <= 1: + fn_args_grads = ctx.bwf(*fn_out_grads) + return (grad,) + fn_args_grads + + grad, task = reduce_scatter_async(grad, group=ctx.group) + fn_args_grads = ctx.bwf(*fn_out_grads) + task and task.wait() + return (grad,) + fn_args_grads + +def allgather_async(input, group=None): + """Perform asynchronous All-Gather operation for model parallelism. + + Args: + input (Tensor): Local tensor to gather (shape: [N, ...]) + group (ProcessGroup): Model parallel group (default: auto-detected) + + Returns: + tuple: (output_tensor, communication_task) + output_tensor: Pre-allocated buffer with shape [N*K, ...] 
(K=group_size) + communication_task: Paddle communication task handle for synchronization + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone(), None + output_shape = input.shape + output_shape[0] = output_shape[0] * parallelism + output = paddle.empty(shape=output_shape, dtype=input.dtype) + task = dist.stream.all_gather( + output, input, group=group, use_calc_stream=False, sync_op=False + ) + return output, task + +def reduce_scatter_async(input, group=None): + """Perform asynchronous reduce-scatter operation for distributed training. + + Args: + input (Tensor): Local tensor to reduce (shape: [N*K, ...], N=group_size) + group (ProcessGroup): Communication group (default: model parallel group) + + Returns: + tuple: (output_tensor, communication_task) + output_tensor: Scattered tensor portion with shape [K, ...] + communication_task: Handle for synchronizing the async operation + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone(), None + output_shape = input.shape + assert ( + input.shape[0] % parallelism == 0 + ), f"Input sequence length {input.shape[0]} can't be divided exactly by sequence parallelism {parallelism}" + output_shape[0] = output_shape[0] // parallelism + output = paddle.empty(shape=output_shape, dtype=input.dtype) + task = dist.stream.reduce_scatter( + output, + input, + op=dist.ReduceOp.SUM, + group=group, + use_calc_stream=False, + sync_op=False, + ) + return output, task + diff --git a/paddleformers/nn/moe/alltoall.py b/paddleformers/nn/moe/alltoall.py new file mode 100644 index 00000000000..80e9e800bc5 --- /dev/null +++ b/paddleformers/nn/moe/alltoall.py @@ -0,0 +1,125 @@ +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import Tensor, _C_ops, framework, nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.communication import stream +from paddle.distributed.communication.group import Group +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.nn.functional import moe_combine, moe_gate_dispatch +from paddleformers.utils.log import logger +from paddleformers.transformers.ernie4_5.sequence_parallel_utils import ScatterOp + +from .utils import manual_backward + +class AlltoAll(PyLayer): + """ + Custom PyLayer for All-to-All communication with backward pass. + """ + @staticmethod + def forward(ctx, x, group, sync_op=True): + """ + Perform All-to-All communication in the group. + + Args: + x: Input tensor + group: Communication group + sync_op: Whether to perform synchronous operation + + Returns: + Tensor: Output tensor + """ + ctx.group = group + if dist.get_world_size(group) <= 1: + return x + output = paddle.empty_like(x) + output.stop_gradient = False + task = stream.alltoall_single( + output, x, None, None, group, sync_op=sync_op, use_calc_stream=sync_op + ) + if not sync_op: + return output, task + else: + return output + + @staticmethod + def backward(ctx, *dx): + """ + Backward pass for All-to-All communication. + + Args: + dx: Gradient tensor + + Returns: + Tensor: Gradient after backward All-to-All + """ + return AlltoAll.apply(*dx, group=ctx.group) + +class AlltoAllAsync(PyLayer): + """ + Custom PyLayer for asynchronous All-to-All communication. 
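    Illustrative sketch (not part of this patch): the overlap idiom used by
    AllGatherAsync and AlltoAllAsync is to launch the collective asynchronously,
    run independent work (the routed `fn`), and wait only when the result is
    needed. `gate_fn` below is a hypothetical stand-in for that work.

        import paddle
        from paddle.distributed.communication import stream

        def overlapped_alltoall(x, gate_fn, group):
            # 1) launch the all-to-all without blocking the calc stream
            out = paddle.empty_like(x)
            task = stream.alltoall_single(out, x, None, None, group, sync_op=False)
            # 2) run computation that does not depend on `out`
            gate_out = gate_fn(x)
            # 3) block only when the exchanged tensor is actually needed
            task.wait()
            return out, gate_out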
+ """ + @staticmethod + def forward(ctx, x, *fn_args, group=None, fn=None, is_first_fwd=False): + """ + Asynchronous All-to-All communication with function execution. + + Args: + x: Input tensor + fn_args: Arguments for the function + group: Communication group + fn: Function to execute + is_first_fwd: Whether this is the first forward pass + + Returns: + tuple: (output tensor, function outputs) + """ + assert fn is not None, "use AlltoAll no async" + ctx.group = group + if dist.get_world_size(group) <= 1: + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + return (x,) + fn_out + x_out = paddle.empty_like(x) + x_out.stop_gradient = False + task = stream.alltoall_single( + x_out, + x, + None, + None, + group, + sync_op=False, + ) + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + task.wait() + return (x_out,) + fn_out + + @staticmethod + def backward(ctx, dx_out, *fn_out_grads): + """ + Backward pass for asynchronous All-to-All. + + Args: + dx_out: Gradient of output + fn_out_grads: Gradients of function outputs + + Returns: + tuple: (gradient tensor, function argument gradients) + """ + if dist.get_world_size(ctx.group) <= 1: + fn_args_grads = ctx.bwf(*fn_out_grads) + return (dx_out,) + fn_args_grads + + dx = paddle.empty_like(dx_out) + dx.stop_gradient = False + task = stream.alltoall_single( + dx, + dx_out, + None, + None, + ctx.group, + sync_op=False, + ) + fn_args_grads = ctx.bwf(*fn_out_grads) + task.wait() + return (dx,) + fn_args_grads \ No newline at end of file diff --git a/paddleformers/nn/moe/alltoall_smart.py b/paddleformers/nn/moe/alltoall_smart.py new file mode 100644 index 00000000000..067b7314cb8 --- /dev/null +++ b/paddleformers/nn/moe/alltoall_smart.py @@ -0,0 +1,654 @@ +import inspect +from typing import Callable, Dict, List, Optional, Tuple + +import paddle +import paddle.distributed as dist +from paddle import framework, nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.communication.group import Group, _get_global_group +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.nn.functional import ( + build_src_rank_and_local_expert_id, + expand_modality_expert_id, + moe_gate_dispatch_partial_nosoftmaxtopk, +) +from paddle.incubate.tensor.manipulation import async_offload +from paddleformers.peft.lora.lora_quantization_layers import QuantizationLoRALinear +from paddleformers.utils.log import logger + +from paddleformers.transformers.ernie4_5.distributed.common_dist_utils import ( + AllGatherGroupOp, + ReduceScatterGroupOp, + all_gather_group, + get_async_loader, + hack_offload_wait, + reduce_scatter_group, +) + +from .utils import manual_backward + +class AlltoAllSmart(paddle.autograd.PyLayer): + """ + Perform dispatch inputs alltoall. + """ + + @staticmethod + def forward( + ctx, + *inputs, + router_loss_fn: Optional[Callable], + forward_func_dict: Optional[Dict[int, Callable]], + local_expert_id=None, + send_rank_global=None, + recv_rank_global=None, + num_local_experts=None, + capacity=None, + use_padding=True, + expert_num_global=None, + is_first_fwd=None, + group=None, + recv_size=None, + send_counts=None, + recv_counts=None, + send_counts_num=None, + recv_counts_num=None, + ): + """Implements batched point-to-point communication with expert computation overlap. 
+ + Functional Behavior: + - Performs distributed All-to-All communication with variable message sizes + - Overlaps expert forward computation with communication operations + - Calculates router loss for dynamic expert selection + - Handles padding/compression for irregular tensor shapes + + Key Operations: + 1. Prepare communication buffers based on send/recv counts + 2. Launch asynchronous All-to-All operations + 3. Execute expert forward functions in parallel with communication + 4. Calculate routing loss and prepare gradient masks + + Args: + ctx: PyLayer context object + *inputs: Variable-length expert inputs (Tensor[...]) + router_loss_fn: Routing loss calculator function + forward_func_dict: Expert-specific forward functions {expert_id: callable} + local_expert_id: Tensor indicating local expert assignments + send_rank_global: Global ranks for sending data + recv_rank_global: Global ranks for receiving data + num_local_experts: Number of experts per device + capacity: Maximum tokens per expert + use_padding: Enable padding for fixed-size buffers + expert_num_global: Global expert count + is_first_fwd: Flag for activation checkpointing + group: Process group for communication + recv_size: Precomputed receive buffer size + send_counts: Per-expert send counts [num_local_experts, world_size] + recv_counts: Per-expert recv counts [num_local_experts, world_size] + send_counts_num: Aggregated send expert + recv_counts_num: Aggregated recv counts per expert + + Returns: + tuple: (output_tensor, router_loss, gradient_mask) + """ + if group is None: + group = _get_global_group() + router_loss_args = inputs[num_local_experts:] + inputs = inputs[:num_local_experts] + + ctx.group = group + ctx.use_padding = use_padding + ctx.num_local_experts = num_local_experts + ctx.input_shape = [i.shape if i is not None else None for i in inputs] + + this_rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + capacity = len(send_rank_global) // world_size // num_local_experts + ctx.capacity = capacity + assert len(local_expert_id) == len(recv_rank_global), ( + len(local_expert_id), + len(recv_rank_global), + ) + + for i in inputs: + if i is not None: + input_dtype = i.dtype + input_shape = i.shape + break + else: + raise RuntimeError("all inputs are None") + + output = paddle.zeros([recv_size] + input_shape[1:], dtype=input_dtype) + output_ptr = 0 + + tasks = [] + dummy_input = paddle.empty([0] + input_shape[1:], dtype=input_dtype) + ctx.dummy_input = dummy_input + ctx.bw_funcs = {} + + for i_local_expert in range(num_local_experts): + send_count = send_counts[i_local_expert] + recv_count = recv_counts[i_local_expert] + assert len(recv_count) == len(send_count) == (world_size), ( + len(recv_count), + len(send_count), + ) + + if send_counts_num[i_local_expert] > 0: + input_local_expert = inputs[i_local_expert].slice( + (0,), 0, send_counts_num[i_local_expert] + ) + if forward_func_dict is not None: + input_local_expert.stop_gradient = False + bwf, (input_local_expert,) = manual_backward( + forward_func_dict[i_local_expert], + is_first_fwd, + input_local_expert, + ) + ctx.bw_funcs[i_local_expert] = bwf + + if input_local_expert is None: + input_local_expert = dummy_input + input_local_expert.stop_gradient = True + else: + input_local_expert = dummy_input + if recv_counts_num[i_local_expert] > 0: + # When FLAGS_use_stride_kernel=0, tensor.slice(...) returns a + # new tensor instead of a view, causing in-place assignment to fail. + # tensor._slice ensures it always returns a view. 
+ # See: + # https://github.com/PaddlePaddle/Paddle/blob/release/3.1/paddle/phi/core/dense_tensor_impl.cc#L299 + output_local_expert = output._slice( + output_ptr, (output_ptr + recv_counts_num[i_local_expert]) + ) + else: + output_local_expert = dummy_input + + output_ptr += recv_counts_num[i_local_expert] + + if group.nranks <= 1: + output_local_expert[:] = input_local_expert[:] + else: + tasks.append( + dist.stream.alltoall_single( + output_local_expert, + input_local_expert, + recv_count, + send_count, + group=group, + sync_op=False, + use_calc_stream=False, + ) + ) + ctx.router_loss_bwfn, (router_loss,) = manual_backward( + router_loss_fn, is_first_fwd, *router_loss_args + ) + with paddle.no_grad(): + recv_mask = (recv_rank_global == this_rank).astype(send_rank_global.dtype) + if ctx.use_padding: + recv_mask_alltoall_out = ( + recv_mask.reshape([-1, num_local_experts, capacity]) + .transpose([1, 0, 2]) + .reshape([-1]) + ) + distributed_input_to_alltoall_out = paddle.maximum( + recv_mask_alltoall_out.cumsum() - 1, + paddle.zeros([1], dtype=recv_mask_alltoall_out.dtype), + ) + distributed_input_to_alltoall_out = ( + distributed_input_to_alltoall_out.view( + [num_local_experts, -1, capacity] + ) + .transpose([1, 0, 2]) + .reshape([-1]) + ) + else: + recv_mask_alltoall_out = recv_mask.split( + expert_num_global + ) # h->d copy break overlap + recv_mask_alltoall_out = [ + recv_mask_alltoall_out[ + (iexpert % world_size) * num_local_experts + + (iexpert // world_size) + ] + for iexpert in range(world_size * num_local_experts) + ] + alltoall_shape = [i.shape[0] for i in recv_mask_alltoall_out] + + recv_mask_alltoall_out = paddle.concat(recv_mask_alltoall_out, 0) + distributed_input_to_alltoall_out = paddle.maximum( + recv_mask_alltoall_out.cumsum() - 1, + paddle.zeros([1], dtype=recv_mask_alltoall_out.dtype), + ) + distributed_input_to_alltoall_out = ( + distributed_input_to_alltoall_out.split(alltoall_shape) + ) + + distributed_input_to_alltoall_out = paddle.concat( + [ + distributed_input_to_alltoall_out[ + (iexpert % num_local_experts) * world_size + + (iexpert // num_local_experts) + ] + for iexpert in range(world_size * num_local_experts) + ], + 0, + ) + + distributed_input_to_alltoall_out.stop_gradient = True + for t in tasks: + t and t.wait() + ctx.send_counts = send_counts + ctx.recv_counts = recv_counts + return output, router_loss, distributed_input_to_alltoall_out + + @staticmethod + def backward( + ctx, + out_grad, + d_routerloss, + _, # scatter-idx no grad + ): + """Performs distributed gradient propagation for expert-parallel models. + + Functional Behavior: + - Distributes output gradients via reverse All-to-All communication + - Computes expert-specific gradients using stored backward functions + - Aggregates routing loss gradients + + Key Operations: + 1. Prepare gradient buffers based on forward pass metadata + 2. Execute reverse All-to-All communication + 3. Apply expert-specific backward computations + 4. 
Combine gradients from all sources + + Args: + ctx: Context object storing forward pass information + out_grad (Tensor): Gradient from downstream layers + d_routerloss (Tensor): Routing loss gradient + _: Ignored placeholder + + Returns: + tuple: Combined gradients (expert gradients + router loss gradients) + """ + + grads = [ + paddle.zeros(s, dtype=out_grad.dtype) if s is not None else None + for s in ctx.input_shape + ] + assert len(grads) == ctx.num_local_experts + out_ptr = 0 + tasks = [] + tmp_g = [] + send_counts_num = ctx.send_counts.sum(-1) + recv_counts_num = ctx.recv_counts.sum(-1) + out_grad = out_grad.contiguous() + for i_local_expert in range(ctx.num_local_experts): + send_count = ctx.send_counts[i_local_expert] + recv_count = ctx.recv_counts[i_local_expert] + if recv_counts_num[i_local_expert] > 0: + out_g = out_grad.slice( + (0,), out_ptr, out_ptr + recv_counts_num[i_local_expert] + ) + else: + out_g = ( + ctx.dummy_input + ) # paddle.empty([0,]+out_grad.shape[1:], dtype=out_grad.dtype) + if send_counts_num[i_local_expert] > 0: + # When FLAGS_use_stride_kernel=0, tensor.slice(...) returns a + # new tensor instead of a view, causing in-place assignment to fail. + # tensor._slice ensures it always returns a view. + # See: + # https://github.com/PaddlePaddle/Paddle/blob/release/3.1/paddle/phi/core/dense_tensor_impl.cc#L299 + g = grads[i_local_expert]._slice(0, send_counts_num[i_local_expert]) + else: + g = ctx.dummy_input + tmp_g.append(g) + out_ptr += recv_counts_num[i_local_expert] + if ctx.group.nranks <= 1: + g[:] = out_g[:] + else: + task = dist.stream.alltoall_single( + g, + out_g, + send_count, + recv_count, + group=ctx.group, + sync_op=False, + use_calc_stream=False, + ) + tasks.append(task) + router_fn_args_grad = ctx.router_loss_bwfn(d_routerloss) + + for i_local_expert, t in enumerate(tasks): + t and t.wait() + send_cnt = send_counts_num[i_local_expert] + if send_cnt > 0 and ctx.bw_funcs: + (g,) = ctx.bw_funcs[i_local_expert](tmp_g[i_local_expert]) + grads[i_local_expert][:send_cnt] = g + + grads = [g for g in grads if g is not None] + return tuple(grads) + tuple(router_fn_args_grad) + +class AlltoAllSmartXPU(paddle.autograd.PyLayer): + """ + Perform dispatch inputs alltoall. 
(XPU VERSION) + """ + + @staticmethod + def forward( + ctx, + *inputs, + router_loss_fn: Optional[Callable], + forward_func_dict: Optional[Dict[int, Callable]], + local_expert_id=None, + send_rank_global=None, + recv_rank_global=None, + num_local_experts=None, + capacity=None, + use_padding=True, + expert_num_global=None, + is_first_fwd=None, + group=None, + recv_size=None, + send_counts=None, + recv_counts=None, + send_counts_num=None, + recv_counts_num=None, + ): + if group is None: + group = _get_global_group() + router_loss_args = inputs[num_local_experts:] + inputs = inputs[:num_local_experts] + + ctx.group = group + ctx.use_padding = use_padding + ctx.num_local_experts = num_local_experts + ctx.input_shape = [i.shape if i is not None else None for i in inputs] + ctx.send_counts = send_counts + ctx.recv_counts = recv_counts + ctx.send_counts_num = send_counts_num + ctx.recv_counts_num = recv_counts_num + + world_size = dist.get_world_size(group) + this_rank = dist.get_rank(group) + if use_padding and capacity is None: + capacity = len(send_rank_global) // world_size // num_local_experts + + for i in inputs: + if i is not None: + input_dtype = i.dtype + input_shape = i.shape + break + else: + first_expert = forward_func_dict[0] + input_dtype = first_expert.up_gate_proj.weight.dtype + hidden_size = first_expert.up_gate_proj.weight.shape[0] + input_shape = [0, hidden_size] + + dummy_input = paddle.empty([0] + input_shape[1:], dtype=input_dtype) + ctx.dummy_input = dummy_input + ctx.bw_funcs = {} + + processed_inputs = [] + no_tokens_expert_outputs = [] + + for i_local_expert in range(num_local_experts): + if send_counts_num[i_local_expert] > 0: + input_local_expert = inputs[i_local_expert].slice( + (0,), 0, send_counts_num[i_local_expert] + ) + if forward_func_dict is not None: + input_local_expert.stop_gradient = False + bwf, (processed_input,) = manual_backward( + forward_func_dict[i_local_expert], + is_first_fwd, + input_local_expert, + ) + ctx.bw_funcs[i_local_expert] = bwf + processed_input.stop_gradient = True + else: + processed_input = input_local_expert + processed_inputs.append(processed_input) + elif forward_func_dict is not None: + expert_func = forward_func_dict[i_local_expert] + fake_chunk = paddle.zeros( + [1, expert_func.up_gate_proj.weight.shape[0]], + dtype=expert_func.up_gate_proj.weight.dtype, + ) + if expert_func.training: + fake_chunk.stop_gradient = False + + _, (expert_out,) = manual_backward( + expert_func, is_first_fwd, fake_chunk + ) + + no_tokens_expert_outputs.append(expert_out * 0.0) + + all_processed_inputs = ( + paddle.concat(processed_inputs, axis=0) if processed_inputs else dummy_input + ) + + if no_tokens_expert_outputs: + if all_processed_inputs.shape[0] > 0: + all_processed_inputs[0] = all_processed_inputs[0] + sum( + no_tokens_expert_outputs + ) + else: + router_loss_args = list(router_loss_args) + router_loss_args[0] = ( + router_loss_args[0] + sum(no_tokens_expert_outputs).mean() * 0.0 + ) + + in_tensors_by_rank = [[] for _ in range(world_size)] + processed_input_ptr = 0 + for i_local_expert in range(num_local_experts): + num_tokens = send_counts_num[i_local_expert] + if num_tokens > 0: + expert_input = all_processed_inputs.slice( + [0], processed_input_ptr, processed_input_ptr + num_tokens + ) + processed_input_ptr += num_tokens + splits = expert_input.split( + send_counts[i_local_expert].tolist(), axis=0 + ) + for j_rank in range(world_size): + in_tensors_by_rank[j_rank].append(splits[j_rank]) + + in_tensor_list = [ + paddle.concat(tensors, 0) if 
tensors else dummy_input + for tensors in in_tensors_by_rank + ] + + all_to_all_input = paddle.concat(in_tensor_list, 0) + send_counts_for_api = [t.shape[0] for t in in_tensor_list] + + recv_counts_tensor = paddle.to_tensor(recv_counts) + recv_counts_for_api = [ + int(recv_counts_tensor[:, j_rank].sum()) for j_rank in range(world_size) + ] + temp_output = paddle.empty( + [recv_size.item()] + input_shape[1:], dtype=input_dtype + ) + + if group.nranks <= 1: + task = None + if all_to_all_input.shape[0] > 0: + temp_output[:] = all_to_all_input[:] + else: + task = dist.stream.alltoall_single( + temp_output, + all_to_all_input, + recv_counts_for_api, + send_counts_for_api, + group=group, + sync_op=False, + use_calc_stream=False, + ) + + ctx.router_loss_bwfn, (router_loss,) = manual_backward( + router_loss_fn, is_first_fwd, *router_loss_args + ) + with paddle.no_grad(): + recv_mask = (recv_rank_global == this_rank).astype(send_rank_global.dtype) + if ctx.use_padding: + recv_mask_alltoall_out = ( + recv_mask.reshape([-1, num_local_experts, capacity]) + .transpose([1, 0, 2]) + .reshape([-1]) + ) + distributed_input_to_alltoall_out = paddle.maximum( + recv_mask_alltoall_out.cumsum() - 1, + paddle.zeros([1], dtype=recv_mask_alltoall_out.dtype), + ) + distributed_input_to_alltoall_out = ( + distributed_input_to_alltoall_out.view( + [num_local_experts, -1, capacity] + ) + .transpose([1, 0, 2]) + .reshape([-1]) + ) + else: + recv_mask_alltoall_out = recv_mask.split(expert_num_global) + recv_mask_alltoall_out = [ + recv_mask_alltoall_out[ + (iexpert % world_size) * num_local_experts + + (iexpert // world_size) + ] + for iexpert in range(world_size * num_local_experts) + ] + alltoall_shape = [i.shape[0] for i in recv_mask_alltoall_out] + recv_mask_alltoall_out = paddle.concat(recv_mask_alltoall_out, 0) + distributed_input_to_alltoall_out = paddle.maximum( + recv_mask_alltoall_out.cumsum() - 1, + paddle.zeros([1], dtype=recv_mask_alltoall_out.dtype), + ) + distributed_input_to_alltoall_out = ( + distributed_input_to_alltoall_out.split(alltoall_shape) + ) + distributed_input_to_alltoall_out = paddle.concat( + [ + distributed_input_to_alltoall_out[ + (iexpert % num_local_experts) * world_size + + (iexpert // num_local_experts) + ] + for iexpert in range(world_size * num_local_experts) + ], + 0, + ) + + distributed_input_to_alltoall_out.stop_gradient = True + + if task is not None: + task.wait() + + temp_output_splits_by_src_rank = temp_output.split(recv_counts_for_api, 0) + chunks_by_expert = [[] for _ in range(num_local_experts)] + for j_rank in range(world_size): + data_from_j = temp_output_splits_by_src_rank[j_rank] + expert_chunks_from_j = data_from_j.split(recv_counts[:, j_rank].tolist(), 0) + for i_expert in range(num_local_experts): + chunks_by_expert[i_expert].append(expert_chunks_from_j[i_expert]) + + output_chunks = [] + for i_expert in range(num_local_experts): + if recv_counts_num[i_expert] > 0: + output_chunks.append(paddle.concat(chunks_by_expert[i_expert], 0)) + output = paddle.concat(output_chunks, 0) if output_chunks else dummy_input + + return output, router_loss, distributed_input_to_alltoall_out + + @staticmethod + def backward( + ctx, + out_grad, + d_routerloss, + _, # scatter-idx no grad + ): + world_size = dist.get_world_size(ctx.group) + num_local_experts = ctx.num_local_experts + dummy_input = ctx.dummy_input + out_grad = out_grad.contiguous() + + send_counts_bw = ctx.recv_counts + send_counts_num_bw = ctx.recv_counts_num + in_tensors_by_rank_bw = [[] for _ in range(world_size)] + 
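+        # Mirror the forward packing in reverse: re-split each expert's output
+        # gradient by destination rank so the backward all-to-all reuses the
+        # same per-(expert, rank) layout that produced the forward output.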
grad_ptr = 0 + for i_expert in range(num_local_experts): + num_tokens = send_counts_num_bw[i_expert] + if num_tokens > 0: + expert_grad = out_grad.slice([0], grad_ptr, grad_ptr + num_tokens) + grad_ptr += num_tokens + splits = expert_grad.split(send_counts_bw[i_expert].tolist(), 0) + for j_rank in range(world_size): + in_tensors_by_rank_bw[j_rank].append(splits[j_rank]) + in_tensor_list_bw = [ + paddle.concat(tensors, 0) if tensors else dummy_input + for tensors in in_tensors_by_rank_bw + ] + + all_to_all_grad_input = paddle.concat(in_tensor_list_bw, 0) + send_counts_bw_for_api = [t.shape[0] for t in in_tensor_list_bw] + + recv_counts_bw = ctx.send_counts + recv_counts_tensor_bw = paddle.to_tensor(recv_counts_bw) + recv_counts_bw_for_api = [ + int(recv_counts_tensor_bw[:, j_rank].sum()) for j_rank in range(world_size) + ] + total_output_grad_size = int(ctx.send_counts_num.sum()) + temp_grad_output = paddle.empty( + [total_output_grad_size] + list(out_grad.shape[1:]), dtype=out_grad.dtype + ) + + if ctx.group.nranks <= 1: + task = None + if all_to_all_grad_input.shape[0] > 0: + temp_grad_output[:] = all_to_all_grad_input[:] + else: + task = dist.stream.alltoall_single( + temp_grad_output, + all_to_all_grad_input, + recv_counts_bw_for_api, + send_counts_bw_for_api, + group=ctx.group, + sync_op=False, + use_calc_stream=False, + ) + + router_fn_args_grad = ctx.router_loss_bwfn(d_routerloss) + + if task is not None: + task.wait() + + temp_grad_output_splits = temp_grad_output.split(recv_counts_bw_for_api, 0) + grad_chunks_by_expert = [[] for _ in range(num_local_experts)] + for j_rank in range(world_size): + data_from_j = temp_grad_output_splits[j_rank] + expert_chunks_from_j = data_from_j.split( + recv_counts_bw[:, j_rank].tolist(), 0 + ) + for i_expert in range(num_local_experts): + grad_chunks_by_expert[i_expert].append(expert_chunks_from_j[i_expert]) + + grads = [ + paddle.zeros(s, dtype=out_grad.dtype) if s is not None else None + for s in ctx.input_shape + ] + for i_expert in range(num_local_experts): + num_tokens = ctx.send_counts_num[i_expert] + if num_tokens > 0: + reconstructed_grad = paddle.concat(grad_chunks_by_expert[i_expert], 0) + if i_expert in ctx.bw_funcs: + (final_grad,) = ctx.bw_funcs[i_expert](reconstructed_grad) + else: + final_grad = reconstructed_grad + if grads[i_expert] is not None: + grads[i_expert][:num_tokens] = final_grad + + grads = [g for g in grads if g is not None] + return tuple(grads) + tuple(router_fn_args_grad) + + +# Conditionally select the AlltoAllSmart implementation +if paddle.is_compiled_with_xpu(): + AlltoAllSmart = AlltoAllSmartXPU \ No newline at end of file diff --git a/paddleformers/nn/moe/moe.py b/paddleformers/nn/moe/moe.py new file mode 100644 index 00000000000..24c53703f12 --- /dev/null +++ b/paddleformers/nn/moe/moe.py @@ -0,0 +1,242 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
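The AlltoAllSmart layers above build on a variable-size exchange: each rank
concatenates its per-destination chunks, passes explicit send/recv row counts
to `dist.stream.alltoall_single`, overlaps other work, then waits. A minimal
sketch of that pattern follows; the function name and arguments are
illustrative, not part of this patch.

    import paddle
    import paddle.distributed as dist

    def variable_size_alltoall(chunks_per_rank, recv_counts, group):
        # chunks_per_rank[j]: rows this rank sends to rank j
        # recv_counts[j]:     rows this rank expects from rank j
        send_counts = [c.shape[0] for c in chunks_per_rank]
        inp = paddle.concat(chunks_per_rank, axis=0)
        out = paddle.empty([sum(recv_counts)] + inp.shape[1:], dtype=inp.dtype)
        task = dist.stream.alltoall_single(
            out, inp, recv_counts, send_counts,
            group=group, sync_op=False, use_calc_stream=False,
        )
        # ...independent computation can be overlapped here...
        task.wait()
        # rows arrive grouped by source rank, in rank order
        return out.split(recv_counts, axis=0)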
+
+from typing import List, Optional, Tuple
+
+import paddle
+import paddle.distributed as dist
+from paddle import nn
+from paddle.distributed import fleet
+from paddle.distributed.communication.group import Group
+from paddleformers.utils.log import logger
+
+from .moe_alltoall import moe_alltoall_forward
+from .moe_allgather import moe_allgather_forward
+
+class MOE(nn.Layer):
+    _global_mapping = {
+        "alltoall": moe_alltoall_forward,
+        "allgather": moe_allgather_forward,
+    }
+
+    def __init__(
+        self,
+        gate: nn.Layer,
+        experts: List[nn.Layer],
+        layer_idx,
+        shared_experts: Optional[List[nn.Layer]] = None,
+        group: Group = None,
+        recompute=False,
+        k=2,
+        enable_reverse_token_drop=False,
+        all_to_all_dropout=0,
+        group_experts=False,
+        use_expert_out_alltoall=True,  #
+        use_padding=True,
+        dense_token_type=3,  # considered as dense tokens (no moe)
+        moe_statics=None,
+        moe_num_experts=None,
+        moe_mode="allgather",
+    ):
+        """
+        Initialize MoE layer.
+
+        Args:
+            gate: Gate network for expert selection
+            experts: List of expert networks
+            layer_idx: Index of this layer in the model
+            group: Distributed communication group
+            recompute: Whether to enable recomputation
+            k: Number of experts to select per token
+            all_to_all_dropout: Dropout rate for all-to-all communication
+            group_experts: Whether to group experts
+            moe_statics: MoE statistics tracking object
+        """
+        super().__init__()
+        self.gate = gate
+        self.layer_idx = layer_idx
+        self.recompute = recompute
+        for p in self.gate.parameters():
+            p.is_gate = True
+        if isinstance(experts, nn.LayerList):
+            self.experts = experts
+        else:
+            logger.info(f"using fused experts, type={type(experts)}")
+            self.experts = experts
+        self.shared_experts = shared_experts
+
+        self.group = group
+        self.k = k
+        self.all_to_all_dropout = all_to_all_dropout
+        self.use_correction_bias = moe_statics is not None
+        self.moe_statics = moe_statics
+        if self.use_correction_bias:
+            logger.info(
+                f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}"
+            )
+            assert self.gate.config.moe_use_aux_free
+
+        self.is_mp_moe = (
+            hasattr(fleet.fleet, "_hcg")
+            and group is fleet.get_hybrid_communicate_group().get_model_parallel_group()
+        )
+        is_dummy_moe = dist.get_world_size(group) == 1
+
+        for p in experts.parameters():
+            p.expert = not (self.is_mp_moe or is_dummy_moe)  # type: ignore
+            p.no_sync = not (self.is_mp_moe or is_dummy_moe)
+            if self.is_mp_moe:
+                p.is_distributed = True
+                p.mp_moe = True
+
+        self.world_size = dist.get_world_size(self.group)
+        # assert self.world_size > 1, f'moe-group not found, world_size {self.world_size}'
+        self.rank = dist.get_rank(self.group)
+        if self.world_size < 1:
+            self.world_size = 1
+        if self.rank < 0:
+            self.rank = 0
+
+        # self.multimodal_experts = (
+        #     isinstance(moe_num_experts, (tuple, list)) and len(moe_num_experts) > 1
+        # )
+        self.num_local_experts = len(self.experts) // self.world_size
+        # if self.multimodal_experts:
+        #     self.num_local_multimodal_experts = [
+        #         num // self.world_size for num in moe_num_experts
+        #     ]
+        #     self.multimodal_expert_index = [0] + list(
+        #         itertools.accumulate(moe_num_experts)
+        #     )
+
+        self.input_preprocess = self.output_postprocess = None
+        self.group_experts = group_experts
+        self.config = self.gate.config
+        # self.zero = paddle.to_tensor(0, dtype=paddle.float32)
+        self.moe_mode = moe_mode
+
+        if self.moe_mode == "allgather":
+            self.enable_reverse_token_drop = enable_reverse_token_drop
+            self.is_allgather_moe_layer = True
+            self.use_padding = use_padding
+
+            # global gate gather
+            self.send_rank = None
+            self.local_expert_id = None
+            self.dense_experts = None
+            self.dense_token_type = dense_token_type
+            self.capacity_tensor = None
+            self.use_expert_out_alltoall = use_expert_out_alltoall
+            logger.info(
+                f"using MOEAllGatherLayerV2, use_expert_out_alltoall={use_expert_out_alltoall}, "  # false
+                f"use_padding={use_padding}, enable_reverse_token_drop={self.enable_reverse_token_drop}"  # true false
+            )
+            # self.two = paddle.to_tensor(2, dtype=paddle.float32)
+
+    def forward(
+        self,
+        input: paddle.Tensor,
+        token_type_ids=None,
+        use_dense_expert=False,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        if self.moe_mode == "allgather":
+            return moe_allgather_forward(
+                input=input,
+                token_type_ids=token_type_ids,
+                use_dense_expert=use_dense_expert,
+                config=self.config,
+                gate=self.gate,
+                k=self.k,
+                use_correction_bias=self.use_correction_bias,
+                moe_statics=self.moe_statics,
+                world_size=self.world_size,
+                num_local_experts=self.num_local_experts,
+                shared_experts=self.shared_experts,
+                group=self.group,
+                experts=self.experts,
+                rank=self.rank,
+                isRecompute=self.recompute,
+                isTraining=self.training,
+                layer_idx=self.layer_idx,
+                dense_token_type=self.dense_token_type,
+            )
+        elif self.moe_mode == "alltoall":
+            return moe_alltoall_forward(
+                input=input,
+                token_type_ids=token_type_ids,
+                config=self.config,
+                gate=self.gate,
+                k=self.k,
+                use_correction_bias=self.use_correction_bias,
+                moe_statics=self.moe_statics,
+                world_size=self.world_size,
+                num_local_experts=self.num_local_experts,
+                shared_experts=self.shared_experts,
+                group=self.group,
+                experts=self.experts,
+                rank=self.rank,
+                isRecompute=self.recompute,
+                isTraining=self.training,
+                layer_idx=self.layer_idx,
+            )
+        else:
+            raise ValueError("Unsupported MOE mode: {}".format(self.moe_mode))
+
+
+class MoEStatics(nn.Layer):
+    """
+    Stores MoE (Mixture of Experts) statistics
+    and expert usage information.
+    """
+
+    def __init__(self, config, layer_idx):
+        """
+        Initialize MoE statistics tracking.
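        For orientation, a hypothetical sketch of how this statistics object is
        paired with the MOE layer defined above; `cfg`, `gate`, `experts`, and
        `hidden_states` are assumed to exist and are not defined in this patch:

            stats = MoEStatics(cfg, layer_idx=0)        # correction-bias / usage buffers
            moe = MOE(
                gate=gate,                              # gating network (provides .config)
                experts=experts,                        # nn.LayerList of expert FFNs
                layer_idx=0,
                moe_statics=stats,                      # enables correction-bias routing
                moe_mode="alltoall",                    # or "allgather"
            )
            out, combine_weights, router_loss = moe(hidden_states)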
+ + Args: + config: Model configuration containing MoE parameters + layer_idx: Index of the MoE layer in the model + """ + super().__init__() + self._cast_to_low_precision = False # 兼容develop分支paddle + self._cast_to_low_precison = False + num_experts = ( + config.moe_num_experts[0] + if config.multimodel_experts + else config.moe_num_experts + ) + # if config.multimodel_experts: + # assert ( + # len(set(config.moe_num_experts)) == 1 + # ), f"assume expert group has same size, got: {config.moe_num_experts}" + + with paddle.utils.unique_name.guard(f"mm_layer_{layer_idx}_"): + num_experts_groups = ( + len(config.moe_num_experts) if config.multimodel_experts else 1 + ) + p = self.create_parameter( + shape=[num_experts_groups, num_experts], + dtype="float32", + is_bias=True, + attr=paddle.ParamAttr( + name=paddle.utils.unique_name.generate("corr_bias") + ), + ) + p.stop_gradient = True + self.e_score_correction_bias = p + self.e_score_correction_bias.is_distributed = True + p = paddle.zeros( + shape=[num_experts_groups, num_experts], + dtype="int64", + ) + p.stop_gradient = True + self.expert_usage = p + diff --git a/paddleformers/nn/moe/moe_allgather.py b/paddleformers/nn/moe/moe_allgather.py new file mode 100644 index 00000000000..d3886f735b9 --- /dev/null +++ b/paddleformers/nn/moe/moe_allgather.py @@ -0,0 +1,901 @@ +import inspect +from typing import Callable, Dict, List, Optional, Tuple + +import paddle +import paddle.distributed as dist +from paddle import framework, nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.communication.group import Group, _get_global_group +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.nn.functional import ( + build_src_rank_and_local_expert_id, + expand_modality_expert_id, + moe_gate_dispatch_partial_nosoftmaxtopk, +) +from paddle.incubate.tensor.manipulation import async_offload +from paddleformers.peft.lora.lora_quantization_layers import QuantizationLoRALinear +from paddleformers.utils.log import logger + +from paddleformers.transformers.ernie4_5.distributed.common_dist_utils import ( + AllGatherGroupOp, + ReduceScatterGroupOp, + all_gather_group, + get_async_loader, + hack_offload_wait, + reduce_scatter_group, +) + +from .utils import manual_backward, combine_expert_output, _calc_router_loss, ReshardCombineWeight +from .allgather import AllGatherAsync +from .alltoall_smart import AlltoAllSmart + +def moe_allgather_forward( + input: paddle.Tensor, + token_type_ids=None, + use_dense_expert=False, + config, + gate: nn.Layer, + k, + use_correction_bias, + moe_statics, + world_size, + num_local_experts, + shared_experts=None, + group: Group = None, + experts: List[nn.Layer], + rank, + isRecompute, + isTraining, + layer_idx, + dense_token_type, +) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements forward pass for Mixture-of-Experts (MoE) layer with distributed communication. + + Core Functionality: + - Processes input through gating network to determine expert assignments + - Performs distributed All-to-All communication for expert computation + - Combines expert outputs and calculates routing loss + + Key Features: + 1. Supports both dense and sparse expert computation modes + 2. Implements fused gating and dispatch for performance optimization + 3. Handles sequence length padding/unpadding for irregular inputs + 4. 
Enables communication-computation overlap through asynchronous operations + + Args: + input (Tensor): Input tensor of shape [seq_len, hidden_dim] + token_type_ids: Optional segmentation markers for heterogeneous inputs + use_dense_expert: Flag to enable dense expert computation bypass + + Returns: + tuple: ( + combined_output: Aggregated expert outputs [seq_len, hidden_dim], + combine_weights: Expert combination coefficients, + router_loss: Calculated router balancing loss + ) + """ + if input.ndim == 3: + orig_shape = input.shape + input = input.reshape([-1, input.shape[-1]]) + else: + orig_shape = None + + assert ( + len(input.shape) == 2 + ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" + dispatch_token_type_ids = None + global_dense_expert_mask = None + if token_type_ids is not None: + token_type_ids = token_type_ids[:, :-1].reshape([-1]) + dispatch_token_type_ids = token_type_ids + if config.sequence_parallel: + hcg = fleet.get_hybrid_communicate_group() + rank = hcg.get_model_parallel_rank() + interval = ( + token_type_ids.shape[0] // hcg.get_model_parallel_world_size() + ) + token_type_ids = token_type_ids.slice( + [0], rank * interval, (rank + 1) * interval + ) + token_type_ids.stop_gradient = True + + if use_dense_expert: + global_dense_expert_mask = ( + dispatch_token_type_ids == dense_token_type + ) + + assert gate is not None + # if hasattr(self, "rng") and self.rng.random() < self.all_to_all_dropout: + # orig_shape_2 = input.shape + # output = self.forward_experts(input) + # output += self.gate.weight.sum() * 0.0 # hack for grad + # output = output.reshape(orig_shape or orig_shape_2) # [e*1,c,m] + # return output, None, 0 + ( + dispatched_input, + global_hidden_states, + local_combine_weights, + expert_num_global_no_token_drop, + expert_num_global, + expert_num_global_list, + local_scatter_index, + scatter_index_rev, + router_loss, + (gate_logits, gate_prob), + (gate_logits_mm, gate_prob_mm), + expert_num_local, + ) = fused_gate_and_dispatch( + input=input, + token_type_ids=token_type_ids, + global_dense_expert_mask=global_dense_expert_mask, + gate=gate, + k=k, + use_correction_bias=use_correction_bias, + moe_statics=moe_statics, + world_size=world_size, + num_local_experts=num_local_experts, + shared_experts=shared_experts, + group=group, + experts=experts, + rank=rank, + isRecompute=isRecompute, + isTraining=isTraining, + ) + seqlen_this_mp = input.shape[0] + if len(scatter_index_rev): + recv_rank_local = scatter_index_rev // seqlen_this_mp + else: + recv_rank_local = scatter_index_rev + + # if self.use_padding: + # if self.send_rank is None: + # capacity = self.gate.get_capacity( + # input.shape[0] * self.config.moe_world_size + # ) + # self.send_rank = ( + # paddle.arange(self.config.moe_world_size) + # .repeat_interleave(capacity * self.num_local_experts) + # .astype("int32") # cap + # ) + # self.local_expert_id = ( + # paddle.arange(self.num_local_experts) + # .repeat_interleave(capacity) + # .tile(self.config.moe_world_size) + # .astype(self.send_rank.dtype) + # ) + # recv_rank, recv_rank_task = allgather_async( + # recv_rank_local, group=self.config.moe_group + # ) + # send_rank = self.send_rank + # local_expert_id = self.local_expert_id + + # else: + all_expert_num = sum(expert_num_global_list) + # 非常慢 + if config.moe_group.nranks > 1: + recv_rank = paddle.empty([all_expert_num], dtype=recv_rank_local.dtype) + # 非常慢 + recv_rank_task = dist.stream.alltoall_single( + recv_rank, + recv_rank_local.tile(config.moe_world_size), + [ + sum( + 
expert_num_global_list[ + i + * num_local_experts : (i + 1) + * num_local_experts + ] + ) + for i in range(config.moe_world_size) + ], # output-size + [len(recv_rank_local)] * config.moe_world_size, # input-size + group=config.moe_group, + sync_op=False, + use_calc_stream=False, + ) + else: + recv_rank_task = None + recv_rank = recv_rank_local.tile(config.moe_world_size) + + send_rank, local_expert_id = build_src_rank_and_local_expert_id( + expert_num_global, expert_num_global_list, num_local_experts + ) + + # if not self.use_expert_out_alltoall: + # expert_outs = ( + # recompute(self.forward_experts, *dispatched_input) + # if self.recompute and self.training + # else self.forward_experts(*dispatched_input) + # ) + # expert_outs = paddle.concat( + # [e for e in expert_outs if e is not None], axis=0 + # ) # [e*c,m] + # expert_out_to_combine = AllGatherGroupOp.apply( + # expert_outs, group=self.config.moe_group + # ) # for test + # router_loss2 = self.calc_router_loss_and_logging( + # router_loss, + # gate_logits, + # gate_prob, + # gate_logits_mm, + # gate_prob_mm, + # local_combine_weights, + # expert_num_global_no_token_drop, + # token_type_ids, + # dispatch_token_type_ids, + # ) + # else: + recv_rank_task and recv_rank_task.wait() # wait for recv_rank + + world_size = dist.get_world_size(config.moe_group) + this_rank = dist.get_rank(config.moe_group) + + recv_size = paddle.count_nonzero( + recv_rank == dist.get_rank(config.moe_group) + ) + recv_size = paddle.maximum( + recv_size, paddle.ones([], dtype=recv_size.dtype) + ) + + recv_size_cpu, recv_size_task = async_offload(recv_size, get_async_loader()) + + send_rank_this_rank = paddle.count_nonzero(send_rank == this_rank) + + send_rank_this_rank_cpu, send_rank_this_rank_task = async_offload( + send_rank_this_rank, get_async_loader() + ) + + recv_rank[recv_rank == -1] = world_size + send_recv_count_global = paddle.scatter_nd_add( + paddle.zeros( + [num_local_experts, world_size + 1, world_size + 1], + dtype="int32", + ), + paddle.stack([local_expert_id, send_rank, recv_rank], -1), + paddle.ones([len(send_rank)], dtype="int32"), + ) # [num_local_experts, world_size + 1 , world_size + 1] + send_counts_cpu = send_recv_count_global[:, this_rank, :-1].numpy() + recv_counts_cpu = send_recv_count_global[:, :-1, this_rank].numpy() + send_counts_num_cpu = send_counts_cpu.sum(-1) + recv_counts_num_cpu = recv_counts_cpu.sum(-1) + + dispatched_input = forward_experts(*dispatched_input, experts, rank, num_local_experts) + + if recv_size_task is not None: + recv_size_task.cpu_wait() + if send_rank_this_rank_task is not None: + send_rank_this_rank_task.cpu_wait() + + input_size = sum([len(i) if i is not None else 0 for i in dispatched_input]) + # if self.use_padding or input_size > 1: + if input_size > 1: + assert send_rank_this_rank_cpu.item() == input_size, ( + send_rank, + [len(i) if i is not None else 0 for i in dispatched_input], + ) + + expert_out_to_combine, router_loss2, distributed_input_to_alltoall_out = ( + AlltoAllSmart.apply( + *dispatched_input, + router_loss, + gate_logits, + gate_prob, + gate_logits_mm, + gate_prob_mm, + local_combine_weights, + expert_num_global_no_token_drop, + token_type_ids, + dispatch_token_type_ids, + gate, + layer_idx, + forward_func_dict=None, + router_loss_fn=calc_router_loss_and_logging, + local_expert_id=local_expert_id, + send_rank_global=send_rank, + recv_rank_global=recv_rank, + num_local_experts=num_local_experts, + # capacity=dispatched_input[0].shape[1] if self.use_padding else None, + capacity=None, + 
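+                # padding disabled: each expert chunk keeps its exact token
+                # count instead of being padded to a fixed capacity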
use_padding=False,#self.use_padding, + expert_num_global=expert_num_global_list, + is_first_fwd=not framework._dygraph_tracer()._has_grad, + group=config.moe_group, + recv_size=recv_size_cpu, + send_counts=send_counts_cpu, + recv_counts=recv_counts_cpu, + send_counts_num=send_counts_num_cpu, + recv_counts_num=recv_counts_num_cpu, + ) + ) + # /origin input -> distributed input/ => /origin-input -> alltoall out -input/ + local_scatter_index = distributed_input_to_alltoall_out[local_scatter_index] + local_scatter_index.stop_gradient = True + + # global -> local + combined_output = combine_expert_output( + expert_out_to_combine, local_combine_weights, local_scatter_index + ) + + if shared_experts is not None: + shared_out = shared_experts(input) + combined_output += shared_out + + if orig_shape: + combined_output = combined_output.reshape( + orig_shape[:-1] + [combined_output.shape[-1]] + ) + + return combined_output, local_combine_weights, router_loss2, gate_logits + +def forward_experts( + dispatched_input, + experts: List[nn.Layer], + rank, + num_local_experts + ): + """Execute expert model computations in sequence for Mixture-of-Experts (MoE) layer. + + Core Functionality: + - Distributes dispatched tokens to local expert models + - Handles empty expert inputs with zero-initialized fallback + - Maintains gradient flow for expert outputs + - Aggregates outputs from all active experts + + Args: + *dispatched_input: Variable-length expert-specific input tensors + + Returns: + list: Expert output tensors (None for inactive experts) + + Implementation Details: + 1. Processes valid expert inputs through corresponding expert models + 2. Generates dummy inputs for inactive experts to preserve model structure + 3. Aggregates dummy outputs to first active expert to maintain gradient flow + """ + expert_outputs = [] + assert isinstance(experts, nn.LayerList), type(experts) + + no_tokens_expert_outputs = [] + # if not self.multimodal_experts: + true_experts = experts[ + rank + * num_local_experts : (rank + 1) + * num_local_experts + ] + # else: + # true_experts = [] + # for i, num in enumerate(self.num_local_multimodal_experts): + # current_modal_experts = self.experts[ + # self.multimodal_expert_index[i] : self.multimodal_expert_index[ + # i + 1 + # ] + # ] + # true_experts.extend( + # current_modal_experts[self.rank * num : (self.rank + 1) * num] + # ) + + assert len(dispatched_input) == len(true_experts), ( + len(dispatched_input), + len(true_experts), + ) + + for iexpert, chunk in enumerate(dispatched_input): + if chunk is None: + # QuantizationLoRALinear can not call `.weight`. 
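+            # Experts that received no tokens still run once on a fake token;
+            # the result is scaled by 0.0 and folded into an active expert's
+            # output below, so every expert parameter keeps a gradient path on
+            # this rank.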
+ if not isinstance( + true_experts[iexpert].up_gate_proj, QuantizationLoRALinear + ): + input_shape = [ + 1, + true_experts[iexpert].up_gate_proj.weight.shape[0], + ] + input_dtype = true_experts[iexpert].up_gate_proj.weight.dtype + else: + input_shape = [ + 1, + true_experts[iexpert].up_gate_proj.lora_A.shape[0], + ] + input_dtype = true_experts[iexpert].up_gate_proj.lora_A.dtype + + chunk = paddle.zeros( + input_shape, + input_dtype, + ) + if true_experts[iexpert].training: + chunk.stop_gradient = False + expert_out = true_experts[iexpert](chunk.contiguous()) + no_tokens_expert_outputs.append( + expert_out * 0.0 + ) # mutiply 0.0 to zero out and grad + + expert_outputs.append(None) + continue + + expert_out = true_experts[iexpert](chunk.contiguous()) + expert_outputs.append(expert_out) + + # if self.config.moe_layer_feed_fake_token and len(no_tokens_expert_outputs) > 0: + if len(no_tokens_expert_outputs) > 0: + first_has_tokens_idx = 0 + for idx, expert_out in enumerate(expert_outputs): + if expert_out is not None: + first_has_tokens_idx = idx + break + for idx, expert_out in enumerate(no_tokens_expert_outputs): + expert_outputs[first_has_tokens_idx] += expert_out + + return expert_outputs + +def fused_gate_logits_process_fused( + gate_logits_lm, + gate_logits_mm=None, + token_type_ids=None, + k, + gate, + use_correction_bias, + moe_statics +): + """Process gating logits for expert selection in Mixture-of-Experts (MoE) layers. + + Core Functionality: + - Transforms raw gating logits into expert selection weights and IDs + - Supports both grouped and standard expert selection modes + - Handles bias correction for improved expert load balancing + + Args: + gate_logits_lm (Tensor): Raw gating scores of shape [batch_size, total_experts] + + Returns: + tuple: ( + lm_weight_and_expert_id: Combined tensor containing selection weights + and expert IDs [batch_size, 2*top_k], + prob_flat: Flattened expert probabilities [batch_size, total_experts] + ) + """ + top_k = k + num_expert_per_rank_per_modality = ( + gate_logits_lm.shape[-1] // gate.config.moe_world_size + ) + group_size = gate_logits_lm.shape[-1] // top_k + # if self.group_experts: + # assert not self.use_correction_bias + # gate_logits_lm = gate_logits_lm.reshape( + # [gate_logits_lm.shape[0], top_k, -1] + # ) + # prob_lm = self.gate.act(gate_logits_lm) + # prob_lm_ = prob_lm + # weight_lm, expert_id_lm = prob_lm_.topk(k=1, axis=-1) + # weight_lm = weight_lm.reshape([gate_logits_lm.shape[0], -1]) + # group_size = gate_logits_lm.shape[-1] + # expert_id_lm = expert_id_lm.squeeze(-1) + # else: + prob_lm = gate.act(gate_logits_lm) + if use_correction_bias: + prob_lm_ = ( + prob_lm + moe_statics.e_score_correction_bias[0].detach() + ) + else: + prob_lm_ = prob_lm + weight_lm, expert_id_lm = prob_lm_.topk(k=top_k, axis=-1) + + if use_correction_bias: + batch_idx = ( + paddle.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) + ) + weight_lm = prob_lm[batch_idx, expert_id_lm] # use correct bias + + expert_id_lm = expand_modality_expert_id( + expert_id_lm, + num_expert_per_modality=( + num_expert_per_rank_per_modality if token_type_ids is not None else 0 + ), + group_size=group_size, + modality_offset=0, + is_group_expert=False, # self.group_experts, + ) + expert_id_lm = expert_id_lm.reshape(weight_lm.shape) + lm_weight_and_expert_id = paddle.concat( + [weight_lm, expert_id_lm.astype("float32")], -1 + ) + + if token_type_ids is None or gate_logits_mm is None: + return ( + lm_weight_and_expert_id, + 
prob_lm.reshape([prob_lm.shape[0], -1]), + None, + ) + + prob_mm = gate.act(gate_logits_mm) + if use_correction_bias: + prob_mm_ = prob_mm + moe_statics.e_score_correction_bias[1].detach() + else: + prob_mm_ = prob_mm + weight_mm, expert_id_mm = prob_mm_.topk(k=top_k, axis=-1) + if use_correction_bias: + batch_idx = ( + paddle.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) + ) + weight_mm = prob_mm[batch_idx, expert_id_mm] # use correct bias + + expert_id_mm = expand_modality_expert_id( + expert_id_mm, + num_expert_per_modality=num_expert_per_rank_per_modality, + group_size=group_size, + modality_offset=1, + is_group_expert=False, + ) + expert_id_mm = expert_id_mm.reshape(weight_mm.shape) + mm_weight_and_expert_id = paddle.concat( + [weight_mm, expert_id_mm.astype("float32")], -1 + ) + weight_and_expert = paddle.where( + (token_type_ids == 0).unsqueeze(-1), + lm_weight_and_expert_id, + mm_weight_and_expert_id, + ) + return weight_and_expert, prob_lm.reshape([prob_lm.shape[0], -1]), prob_mm + +def fused_gate_and_dispatch( + input, + token_type_ids=None, + global_dense_expert_mask=None, + gate=gate, + k=k, + use_correction_bias=use_correction_bias, + moe_statics=moe_statics, + world_size=world_size, + num_local_experts=num_local_experts, + ): + """Implements fused expert gating and token dispatch logic for Mixture-of-Experts (MoE) layers. + + Core Functionality: + - Computes expert selection probabilities and routing weights + - Performs distributed token-to-expert assignment + - Handles communication and synchronization in model-parallel environments + + Args: + input (Tensor): Input tensor of shape [seq_len, hidden_dim] + + Returns: + tuple: ( + dispatched_input: Expert-assigned tokens [num_experts, capacity, hidden_dim], + global_hidden_states: Full sequence representations, + local_combine_weights: Local expert combination weights, + expert_num_global_notrunc: Global expert token counts (without capacity truncation), + expert_num_global: Actual expert token counts, + expert_num_global_list: Per-expert token counts, + local_scatter_index: Local token reorganization indices, + scatter_index_rev: Reverse scattering indices, + router_loss: Calculated routing loss, + gate_outputs: Raw gating network outputs, + expert_num_local: Local expert utilization counts + ) + """ + seqlen, d_model = input.shape + args = () + if token_type_ids is not None: + token_type_ids = token_type_ids.reshape([-1]) + args = (token_type_ids,) + + router_loss = paddle.zeros([1], dtype="float32") + router_loss.stop_gradient = False + top_k = k + + def build_weights_and_expert_id( + input, + gate, + k, + use_correction_bias, + moe_statics): + nonlocal token_type_ids, args + logits, capacity, router_loss = gate( + input, *args, transform_weight=False + ) + # if gate.config.multimodel_experts: + # gate_logits_lm, gate_logits_mm = logits.chunk(2, axis=-1) + # else: + gate_logits_lm, gate_logits_mm = logits, None + + weigth_and_expert, gate_prob_lm, gate_prob_mm = ( + fused_gate_logits_process_fused( + gate_logits_lm, + gate_logits_mm, + token_type_ids if global_dense_expert_mask is None else None, + k=k, + gate=gate, + use_correction_bias=use_correction_bias, + moe_statics=moe_statics, + ) + ) + weigth_and_expert = AllGatherGroupOp.apply( + weigth_and_expert, group=gate.config.moe_group + ) + return ( + weigth_and_expert, + gate_logits_lm, + gate_logits_mm, + gate_prob_lm, + gate_prob_mm, + ) + + capacity = gate.get_capacity(input.shape[0]) * world_size + ( + global_hidden_states, + 
combine_weights_and_expert_id, + gate_logits_lm, + gate_logits_mm, + gate_prob_lm, + gate_prob_mm, + ) = AllGatherAsync.apply( + input, + input, + gate, + k, + use_correction_bias, + moe_statics, + fn=build_weights_and_expert_id, + group=gate.config.moe_group, + is_first_fwd=not framework._dygraph_tracer()._has_grad, + ) + combine_weights_unnorm, expert_id = combine_weights_and_expert_id.chunk( + 2, axis=-1 + ) + expert_id = expert_id.cast("int32") + expert_id.stop_gradient = True + num_experts = ( + sum(gate.config.moe_num_experts) + if isinstance(gate.config.moe_num_experts, (tuple, list)) + else gate.config.moe_num_experts + ) # all-experts = 96 + if global_dense_expert_mask is not None: + combine_weights_unnorm[global_dense_expert_mask] = 0.0 + expert_id[global_dense_expert_mask] = num_experts + num_experts += 1 + + if ( + "reverse_token_drop" + in inspect.signature(moe_gate_dispatch_partial_nosoftmaxtopk).parameters + ): + # compat_kwargs = {"reverse_token_drop": self.enable_reverse_token_drop} + compat_kwargs = {"reverse_token_drop": False} + else: + compat_kwargs = {} + + # Disable AMP because: + # - combine_weights_unnorm is fp32, global_hidden_states is bf16 + # - AMP O2 would upcast global_hidden_states to fp32, making dispatched_input fp32 + # - This is a data movement op with no computation, so upcasting is unnecessary + with paddle.amp.auto_cast(False): + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, # input -> dispatched_input + scatter_index_rev, # dispatch-input -> input + expert_num_global, + expert_num_local, + ) = moe_gate_dispatch_partial_nosoftmaxtopk( + global_hidden_states, + combine_weights_unnorm, + expert_id, + top_k, + capacity, + num_experts, + False, # self.use_padding, + expert_start_index=num_local_experts * gate.config.moe_rank, + expert_end_index=num_local_experts * (gate.config.moe_rank + 1), + **compat_kwargs, + ) + + if use_correction_bias: + # if gate.config.multimodel_experts: + # # MLLM + # for i in range(len(self.moe_statics.expert_usage)): + # self.moe_statics.expert_usage[i] += expert_num_local[ + # self.gate.experts_type_mask[i] + # ].detach() + # else: + # # LLM + moe_statics.expert_usage[0] += expert_num_local.detach() + + # When use unpad , `moe_ops_partial` output likes `scatter_index_rev==[]`. 
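+    # i.e. with padding disabled the kernel may hand back a 0-dim tensor here,
+    # so normalize it to an empty 1-D tensor before it is used downstream.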
+ if scatter_index_rev.ndim == 0: + # assert not self.use_padding + scatter_index_rev = paddle.empty([0], dtype=scatter_index_rev.dtype) + + dispatched_input.stop_gradient = False + combine_weights_unnorm.stop_gradient = False + scatter_index.stop_gradient = True + expert_num_global.stop_gradient = True + expert_num_global_notrunc = expert_num_global + capacity_tensor = paddle.to_tensor(capacity, dtype=expert_num_global.dtype) + expert_num_global = paddle.minimum(expert_num_global, capacity_tensor) + + if global_dense_expert_mask is not None: + expert_num_global = expert_num_global[:-1] + expert_num_local = expert_num_local[:-1] + expert_num_global_notrunc = expert_num_global_notrunc[:-1] + + scatter_index = scatter_index.transpose([1, 0]) # [k,s] ->[s,k] + + last_local_expert = num_local_experts * gate.config.moe_rank + expert_offset_global = expert_num_global.cumsum() + + loader = get_async_loader() + expert_num_global_list, offload_task = async_offload(expert_num_global, loader) + # if self.use_padding: + # offset = last_local_expert * capacity + # else: + offset = ( + expert_offset_global[last_local_expert - 1] + if gate.config.moe_rank > 0 + else 0 + ) + local_combine_weights_unnorm = ReshardCombineWeight.apply( + combine_weights_unnorm.contiguous(), group=gate.config.moe_group + ) + local_scatter_index = ReduceScatterGroupOp.apply( + paddle.where( + combine_weights_unnorm > 0.0, + scatter_index + offset, + scatter_index, + ), + group=gate.config.moe_group, + ) + if gate.norm_gate_logits: + local_combine_weights = local_combine_weights_unnorm / paddle.clip( + local_combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 + ) + else: + local_combine_weights = local_combine_weights_unnorm + local_combine_weights = local_combine_weights.cast(dispatched_input.dtype) + # if self.use_padding: + # dispatched_input = dispatched_input.reshape( + # [self.num_local_experts, -1, d_model] + # ) + # dispatched_input = dispatched_input.unbind(0) + # else: + s = num_local_experts * gate.config.moe_rank + e = num_local_experts * (gate.config.moe_rank + 1) + expert_num_local = expert_num_local.tolist()[s:e] + expert_num_local_valid = [i for i in expert_num_local if i > 0] + valid_pos = [j for j, i in enumerate(expert_num_local) if i > 0] + if expert_num_local_valid: + dispatched_input_list = dispatched_input.split(expert_num_local_valid) + dispatched_input = [None] * len(expert_num_local) + for p, t in zip(valid_pos, dispatched_input_list): + dispatched_input[p] = t + else: + dispatched_input = [dispatched_input] + ( + [None] * (len(expert_num_local) - 1) + ) + + scatter_index.stop_gradient = True + scatter_index_rev.stop_gradient = True + if offload_task is not None: + hack_offload_wait(offload_task) + expert_num_global_list = expert_num_global_list.tolist() + + return ( + dispatched_input, + global_hidden_states, + local_combine_weights, + expert_num_global_notrunc, # for auxloss calculation. + expert_num_global, + expert_num_global_list, + local_scatter_index, + scatter_index_rev, + router_loss, + (gate_logits_lm, gate_prob_lm), + (gate_logits_mm, gate_prob_mm), + expert_num_local, + ) + +def calc_router_loss_and_logging( + router_loss, + gate_logits, + gate_prob, + gate_logits_mm, + gate_prob_mm, + combine_weights, + dispatch_mask, + token_type_ids, + dispatch_token_type_ids, + gate, + layer_idx + ): + """Calculate and aggregate router auxiliary loss for Mixture-of-Experts training. 
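# A toy sketch of the per-expert split done just above: the packed buffer only holds
# tokens of experts with a non-zero count, so empty experts receive a `None`
# placeholder to keep the expert list aligned. All names and shapes here are
# illustrative values, not the real configuration.
import paddle

expert_num_local = [3, 0, 2]                          # token count per local expert
packed = paddle.randn([sum(expert_num_local), 8])     # packed dispatch buffer [5, d]
valid_counts = [n for n in expert_num_local if n > 0]
valid_pos = [i for i, n in enumerate(expert_num_local) if n > 0]
per_expert = [None] * len(expert_num_local)
for pos, chunk in zip(valid_pos, packed.split(valid_counts)):
    per_expert[pos] = chunk                           # -> [3, d], None, [2, d]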
+ + Core Functionality: + - Computes expert load balancing loss to prevent expert under-utilization + - Integrates multiple loss components from different routing stages + - Maintains gradient flow for routing mechanism optimization + + Args: + router_loss (Tensor): Accumulated router loss tensor + gate_logits (Tensor): Raw gating network outputs [batch_size, num_experts] + gate_prob (Tensor): Activated gating probabilities [batch_size, num_experts] + combine_weights (Tensor): Expert combination weights [batch_size, top_k] + dispatch_mask (Tensor): Token dispatch mask indicating expert assignments + + Returns: + Tensor: Updated router loss with new auxiliary components + """ + dispatch_mask_3d = dispatch_mask.reshape([gate.config.moe_world_size, -1]) + if token_type_ids is not None and gate.config.moe_use_hard_gate: + # MLLM + if not gate.weight.stop_gradient: + dispatch_tokens_mask = ( + dispatch_token_type_ids == 0 + if dispatch_token_type_ids is not None + else None + ) + lm_tokens_mask = (token_type_ids == 0).astype(gate_prob.dtype) + # hard code + lm_experts = ( + gate.num_experts[0] + if isinstance(gate.num_experts, (tuple, list)) + else gate.num_experts + ) + dispatch_mask_lm = dispatch_mask_3d[ + :, : lm_experts // gate.config.moe_world_size + ].reshape([-1]) + router_loss += _calc_router_loss( + dispatch_mask_lm, + gate_logits * lm_tokens_mask.unsqueeze(-1), + gate_prob * lm_tokens_mask.unsqueeze(-1), + gate.num_experts_list[0], + False, # self.group_experts, + layer_idx, + 0, # ortholoss + lm_tokens_mask, + dispatch_tokens_mask, + prefix="lm", + gate, + ) + else: + zero = paddle.to_tensor(0, dtype=paddle.float32) + router_loss += zero * gate_logits[0, 0] * gate_prob[0, 0] + if gate_prob_mm is not None: + mm_tokens_mask = (token_type_ids == 1).astype(gate_prob_mm.dtype) + dispatch_tokens_mask = ( + dispatch_token_type_ids == 1 + if dispatch_token_type_ids is not None + else None + ) + dispatch_mask_mm = dispatch_mask_3d[ + :, gate.num_experts[0] // gate.config.moe_world_size : + ].reshape([-1]) + + router_loss += _calc_router_loss( + dispatch_mask_mm, + gate_logits_mm * mm_tokens_mask.unsqueeze(-1), + gate_prob_mm * mm_tokens_mask.unsqueeze(-1), + gate.num_experts_list[1], + False, + layer_idx, + 1, + mm_tokens_mask, + dispatch_tokens_mask, + prefix="mm", + gate, + ) + + else: + # LLM + router_loss += _calc_router_loss( + dispatch_mask, + gate_logits, + gate_prob, + gate.num_experts_tensor, + False, #self.group_experts, + layer_idx, + 0, + paddle.ones([gate_prob.shape[0]], "bool"), + paddle.ones( + [gate.config.moe_world_size * gate_prob.shape[0]], "bool" + ), + prefix="lm", + ) + + return router_loss \ No newline at end of file diff --git a/paddleformers/nn/moe/moe_alltoall.py b/paddleformers/nn/moe/moe_alltoall.py new file mode 100644 index 00000000000..3036f25c6f4 --- /dev/null +++ b/paddleformers/nn/moe/moe_alltoall.py @@ -0,0 +1,488 @@ +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import Tensor, _C_ops, framework, nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.communication import stream +from paddle.distributed.communication.group import Group +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.nn.functional import moe_combine, moe_gate_dispatch +from paddleformers.utils.log import logger +from paddleformers.transformers.ernie4_5.sequence_parallel_utils import ScatterOp + +from .utils import combine_expert_output, _calc_router_loss +from 
.alltoall import AlltoAll, AlltoAllAsync +import inspect # used by gate_and_dispatch below for the corr_bias compatibility check +from typing import List, Tuple + +def moe_alltoall_forward( + input: Tensor, + config, + gate: nn.Layer, + k, + use_correction_bias, + moe_statics, + world_size, + num_local_experts, + experts: List[nn.Layer], + rank, + isRecompute, + isTraining, + layer_idx, + token_type_ids=None, + shared_experts=None, + group: Group = None, +) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """ + Forward pass through MoE layer. + + Args: + input: Input tensor of shape [s, d] + + Returns: + tuple: (output, combine_weights, router_loss, gate_logits) + """ + # assert len(input) == 1, "only single input Tensor supported" + if input.ndim == 3: + orig_shape = input.shape + input = input.reshape([-1, input.shape[-1]]) + else: + orig_shape = None + assert ( + len(input.shape) == 2 + ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" + if token_type_ids is not None: + token_type_ids = token_type_ids.clone()[:, :-1] + if config.sequence_parallel: + token_type_ids = token_type_ids.reshape([-1]) + token_type_ids = ScatterOp.apply(token_type_ids) + token_type_ids.stop_gradient = True + + assert gate is not None + + # if hasattr(self, "rng") and self.rng.random() < self.all_to_all_dropout: + # orig_shape_2 = input.shape + # output = self.forward_experts(input) + # output += self.gate.weight.sum() * 0.0 # hack for grad + # output = output.reshape(orig_shape or orig_shape_2) # [e*1,c,m] + # return output, None, 0 + + is_first_fwd = not framework._dygraph_tracer()._has_grad + gate_input = input + + ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + gate_prob, + ) = gate_and_dispatch(gate_input, token_type_ids, gate, k, config, use_correction_bias, moe_statics, world_size, num_local_experts) + + use_async = shared_experts is not None + if use_async: + dispatched_input, shared_out = AlltoAllAsync.apply( + dispatched_input, + input, # args to shared-experts + group=group, + fn=shared_experts, + is_first_fwd=is_first_fwd, + ) + else: + dispatched_input = AlltoAll.apply(dispatched_input, group=group) + + expert_out = ( + recompute(forward_experts, dispatched_input, experts, rank, num_local_experts, world_size) + if isRecompute and isTraining + else forward_experts(dispatched_input, experts, rank, num_local_experts, world_size) + ) + + expert_out, router_loss2 = AlltoAllAsync.apply( + expert_out, + router_loss, + combine_weights, + dispatch_mask, + gate_logits, + gate_prob, + token_type_ids, + gate, + layer_idx, + group=group, + fn=calc_router_loss_and_logging, + is_first_fwd=is_first_fwd, + ) + + combined_output = combine_expert_output( + expert_out, combine_weights, scatter_index + ) + + if shared_experts is not None: + combined_output += shared_out + + if orig_shape: + combined_output = combined_output.clone().reshape( + orig_shape[:-1] + [combined_output.shape[-1]] + ) + return combined_output, combine_weights, router_loss2, gate_logits + +# NOTE: the positional parameter order below must match the fn_args passed to +# AlltoAllAsync.apply in moe_alltoall_forward above. +def calc_router_loss_and_logging( + router_loss, + combine_weights, + dispatch_mask, + gate_logits, + gate_prob, + token_type_ids=None, + gate: nn.Layer = None, + layer_idx=0, + dispatch_token_type_ids=None, + offload_helper=None, + ): + """ + Calculate auxiliary losses and log statistics in fused expert case.
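# The AlltoAll exchange used in moe_alltoall_forward above can be pictured on the
# [world_size * num_local_experts, capacity, d] dispatch buffer: rank r sends slice e
# to the rank owning global expert e and receives the matching slices from every rank.
# A single-process simulation with plain lists (no real communication; toy sizes and
# string payloads stand in for tensors):
world_size, num_local_experts = 2, 2
# send_buffers[r][e] = what rank r routed to global expert e
send_buffers = [[f"rank{r}->expert{e}" for e in range(world_size * num_local_experts)]
                for r in range(world_size)]
recv_buffers = [[send_buffers[src][dst * num_local_experts + j]
                 for src in range(world_size)
                 for j in range(num_local_experts)]
                for dst in range(world_size)]
# After the exchange, rank `dst` holds every source rank's tokens for its own local
# experts, ordered by (source rank, local expert) -- exactly the layout that
# forward_experts later views as [world_size, num_local_experts, capacity, d].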
+ + Args: + router_loss: Base router loss + combine_weights: Combination weights + dispatch_mask: Dispatch mask + gate_logits: Gate logits + gate_prob: Gate probabilities + + Returns: + Tensor: Updated router loss + """ + assert gate_prob is not None + if token_type_ids is not None and gate.config.moe_use_hard_gate: # true + if not gate.weight.stop_gradient: + lm_tokens_mask = token_type_ids == 0 + if offload_helper is not None: + is_lm = offload_helper["lm_mask"][1] + else: + is_lm = lm_tokens_mask.any() + if is_lm: + dispatch_tokens_mask = ( + dispatch_token_type_ids == 0 + if dispatch_token_type_ids is not None + else None + ) + router_loss += _calc_router_loss( + ( + dispatch_mask[gate.experts_type_mask[0]] + if hasattr(gate, "experts_type_mask") + else dispatch_mask + ), + ( + gate_logits[:, gate.experts_type_mask[0]] + if hasattr(gate, "experts_type_mask") + else gate_logits + ), + ( + gate_prob[:, gate.experts_type_mask[0]] + if hasattr(gate, "experts_type_mask") + else gate_prob + ), + ( + gate.num_experts_list[0] + if hasattr(gate, "num_experts_list") + else gate.num_experts_tensor + ), + False, # self.group_experts, + layer_idx, + 0, + lm_tokens_mask, + dispatch_tokens_mask, + prefix="lm", + gate=gate + ) + # mm_tokens_mask = token_type_ids == 1 + # if offload_helper is not None: + # is_mm = offload_helper["mm_mask"][1] + # else: + # is_mm = mm_tokens_mask.any() + # if is_mm: + # dispatch_tokens_mask = ( + # dispatch_token_type_ids == 1 + # if dispatch_token_type_ids is not None + # else None + # ) + # router_loss += self._calc_router_loss( + # dispatch_mask[self.gate.experts_type_mask[1]], + # gate_logits[:, self.gate.experts_type_mask[1]], + # gate_prob[:, self.gate.experts_type_mask[1]], + # self.gate.num_experts_list[1], + # False, + # self.layer_idx, + # 1, + # mm_tokens_mask, + # dispatch_tokens_mask, + # prefix="mm", + # ) + + else: + router_loss += _calc_router_loss( + dispatch_mask, + gate_logits, + gate_prob, + gate.num_experts_tensor, + False,# self.group_experts, + layer_idx, + gate=gate + ) + + return router_loss + +def forward_experts( + dispatched_input, + experts: List[nn.Layer], + rank, + num_local_experts, + world_size, + ): + """ + Forward pass through experts sequentially. 
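# The `_calc_router_loss` calls above aggregate several regularizers; the central one
# is a load-balancing auxiliary loss. A common formulation (in the spirit of
# GShard/Switch; the exact math lives in `gate._cal_aux_loss` and may differ) is
#     l_aux = num_experts * sum_i f_i * P_i
# with f_i the fraction of tokens dispatched to expert i and P_i the mean router
# probability of expert i. Toy sketch:
import paddle
import paddle.nn.functional as F

gate_prob = F.softmax(paddle.randn([16, 4]), axis=-1)     # [tokens, num_experts]
dispatch_count = paddle.to_tensor([5.0, 3.0, 4.0, 4.0])   # tokens routed per expert
f_i = dispatch_count / dispatch_count.sum()
P_i = gate_prob.mean(axis=0)
l_aux = (f_i * P_i).sum() * gate_prob.shape[-1]           # ~1.0 when perfectly balanced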
+ + Args: + dispatched_input: Input tensor of shape [num_experts, capacity, dim] + + Returns: + Tensor: Expert outputs of shape [num_experts, capacity, dim] + """ + + # if not self.multimodal_experts: + # true_experts = self.experts[ + # self.rank + # * self.num_local_experts : (self.rank + 1) + # * self.num_local_experts + # ] + # else: + # true_experts = [] + # for i, num in enumerate(self.num_local_multimodal_experts): + # current_modal_experts = self.experts[ + # self.multimodal_expert_index[i] : self.multimodal_expert_index[ + # i + 1 + # ] + # ] + # true_experts.extend( + # current_modal_experts[self.rank * num : (self.rank + 1) * num] + # ) + true_experts = experts[ + rank + * num_local_experts : (rank + 1) + * num_local_experts + ] + + dispatched_input = dispatched_input.reshape( + [world_size, num_local_experts, -1, dispatched_input.shape[-1]] + ) # [e,1,c,m] + expert_outputs = [] + if isinstance(experts, nn.LayerList): + chunks = dispatched_input.transpose([1, 0, 2, 3]).contiguous().unbind(0) + assert len(chunks) == len(true_experts), (len(chunks), len(true_experts)) + for chunk, expert in zip(chunks, true_experts): + expert_outputs += [expert(chunk)] + else: + dispatched_input = dispatched_input.transpose([1, 0, 2, 3]) + dispatched_input.contiguous() + orig_shape = dispatched_input.shape + chunks = dispatched_input.reshape([orig_shape[0], -1, orig_shape[-1]]) + chunks = experts(chunks) + chunks = chunks.reshape(orig_shape[:-1] + [chunks.shape[-1]]).unbind(0) + expert_outputs += chunks + expert_output = paddle.stack(expert_outputs, axis=1) # [ecm] + return expert_output + +def gate_and_dispatch( + input, + token_type_ids=None, + gate: nn.Layer, + k, + config, + use_correction_bias, + moe_statics, + world_size, + num_local_experts + ): + """ + Calculate gate and dispatch inputs. 
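# Shape walk-through of forward_experts above with tiny Linear experts: the buffer
# received from AlltoAll is [world_size * num_local_experts, capacity, d]; it is viewed
# as [world, n_local, capacity, d], transposed so the local-expert axis comes first,
# and each local expert consumes one [world, capacity, d] slab. Sizes are toy values.
import paddle
from paddle import nn

world, n_local, capacity, d = 2, 2, 4, 8
experts = nn.LayerList([nn.Linear(d, d) for _ in range(n_local)])
buf = paddle.randn([world * n_local, capacity, d])
buf = buf.reshape([world, n_local, capacity, d]).transpose([1, 0, 2, 3])
outs = [expert(chunk) for chunk, expert in zip(buf.unbind(0), experts)]
expert_output = paddle.stack(outs, axis=1)   # back to [world, n_local, capacity, d]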
+ + Args: + input: Input tensor of shape [seq, dim] + + Returns: + tuple: (dispatched_input, combine_weights, dispatch_mask, + scatter_index, router_loss, gate_logits, gate_prob) + """ + seqlen, d_model = input.shape + args = () + if token_type_ids is not None: + token_type_ids = token_type_ids.reshape([-1]) + args = (token_type_ids,) + + ( + gate_logits, + capacity, + router_loss, + ) = gate(input, *args) + # if self.input_preprocess is not None: + # input, gate_logits = self.input_preprocess(input, gate_logits, capacity) + # capacity no use + # k = self.k + prob, max_prob = fused_gate_logits_process( + gate_logits=gate_logits, + token_type_ids=token_type_ids, + k=k, + gate=gate, + config=config) + + if "corr_bias" in inspect.signature(moe_gate_dispatch).parameters: + if use_correction_bias: + compat_args = (moe_statics.e_score_correction_bias[0],) + else: + compat_args = (None,) + else: + assert ( + not use_correction_bias + ), "correction bias not supported, rebuild moe-ops" + compat_args = () + + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, + dispatch_mask, + _, + ) = moe_gate_dispatch( + input, prob, *compat_args, k=k, capacity=capacity, use_pad=True + ) + dispatched_input = dispatched_input.astype(input.dtype) + + dispatch_mask = paddle.diff(F.pad(dispatch_mask, (1, 0))) + if use_correction_bias: + if gate.config.multimodel_experts: + for i in range(len(moe_statics.expert_usage)): + moe_statics.expert_usage[i] += dispatch_mask[ + gate.experts_type_mask[i] + ].detach() + else: + moe_statics.expert_usage[0] += dispatch_mask.detach() + dispatched_input.stop_gradient = False + combine_weights_unnorm.stop_gradient = False + scatter_index.stop_gradient = True + dispatch_mask.stop_gradient = True + + scatter_index = scatter_index.transpose([1, 0]) # [k,s] ->[s,k] + # if self.group_experts: + # if max_prob is not None: + # if token_type_ids is not None: + # p = paddle.ones_like(combine_weights_unnorm.unsqueeze(-1)) + # p = paddle.scatter_nd_add( + # p, paddle.nonzero(token_type_ids == 0), -1 + max_prob + # ) + # else: + # p = max_prob + # combine_weights_unnorm = ( + # combine_weights_unnorm.unsqueeze(-1) * p + # ).squeeze(-1) + # # gate_prob 进行还原 + # prob = (prob.reshape([p.shape[0], k, -1]) * p).reshape([p.shape[0], -1]) + if gate.norm_gate_logits: + combine_weights = combine_weights_unnorm / paddle.clip( + combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 + ) + else: + combine_weights = combine_weights_unnorm + combine_weights = combine_weights.cast(dispatched_input.dtype) + + dispatched_input = dispatched_input.reshape( + [world_size * num_local_experts, capacity, d_model] + ) + dispatch_mask.stop_gradient = True + scatter_index.stop_gradient = True + return ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + prob, + ) + +def fused_gate_logits_process( + gate_logits, + token_type_ids=None, + offload_helper=None, + k, + gate:nn.Layer, + config, + +): + """ + Process and combine gate logits. 
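# Two bookkeeping steps from gate_and_dispatch above, isolated with toy values: the
# dispatch op returns a running (cumulative) count per expert, so diff(pad(...))
# recovers per-expert token counts; the top-k combine weights are then renormalized,
# with a clip so fully dropped tokens (all-zero rows) do not divide by zero.
import paddle
import paddle.nn.functional as F

cumulative = paddle.to_tensor([3.0, 3.0, 7.0, 9.0])           # running total per expert
per_expert = paddle.diff(F.pad(cumulative, [1, 0]))           # -> [3., 0., 4., 2.]

combine_weights_unnorm = paddle.to_tensor([[0.5, 0.3], [0.0, 0.0]])   # [tokens, k]
combine_weights = combine_weights_unnorm / paddle.clip(
    combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12
)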
+ + Args: + gate_logits: Raw gate logits + + Returns: + tuple: (processed probabilities, max probabilities) + """ + experts_type_ids = gate.experts_type_ids + use_hard_gate = config.moe_use_hard_gate + max_prob = None + + if token_type_ids is not None and use_hard_gate: + if offload_helper is None: + offload_helper = dict() + lm_mask = token_type_ids == 0 + is_lm = lm_mask.any() + mm_mask = token_type_ids == 1 + # is_mm = mm_mask.any() + seq_lm = lm_mask.sum() + seq_mm = mm_mask.sum() + lm_mask = lm_mask.unsqueeze(1) & (experts_type_ids == 0).unsqueeze(0) + mm_mask = mm_mask.unsqueeze(1) & (experts_type_ids == 1).unsqueeze(0) + offload_helper["lm_mask"] = [lm_mask, is_lm, seq_lm] + # offload_helper["mm_mask"] = [mm_mask, is_mm, seq_mm] + + is_lm = offload_helper["lm_mask"][1] + prob = paddle.zeros_like(gate_logits) + # 处理 lm_prob + if is_lm: + lm_mask = offload_helper["lm_mask"][0] + seq_lm_cpu = offload_helper["lm_mask"][2] + lm_mask_nonzero = lm_mask.nonzero() + lm_partial_gate_logits = gate_logits.gather_nd(lm_mask_nonzero).reshape( + [seq_lm_cpu, -1] + ) + # if self.group_experts: + # lm_prob = self.gate.act( + # lm_partial_gate_logits.reshape( + # [lm_partial_gate_logits.shape[0], k, -1] + # ) + # ) + # max_prob = lm_prob.max(-1, keepdim=True) # [s_l, k, 1] + # lm_prob /= max_prob + # else: + lm_prob = gate.act(lm_partial_gate_logits) + prob = paddle.scatter_nd_add(prob, lm_mask_nonzero, lm_prob.flatten()) + # 处理 mm_prob + # is_mm = offload_helper["mm_mask"][1] + # if is_mm: + # mm_mask = offload_helper["mm_mask"][0] + # seq_mm_cpu = offload_helper["mm_mask"][2] + # mm_mask_nonzero = paddle.nonzero(mm_mask) + # mm_partial_gate_logits = gate_logits.gather_nd(mm_mask_nonzero).reshape( + # [seq_mm_cpu, -1] + # ) + # mm_prob = gate.act(mm_partial_gate_logits) + # prob = paddle.scatter_nd_add(prob, mm_mask_nonzero, mm_prob.flatten()) + else: + # 处理非硬门和不需要token_type_ids的情况 + # if self.group_experts: + # prob = self.gate.act(gate_logits.reshape([gate_logits.shape[0], k, -1])) + # max_prob = prob.max(-1, keepdim=True) + # prob /= max_prob + # prob = prob.reshape([prob.shape[0], -1]) + # else: + prob = gate.act(gate_logits) + return prob, max_prob \ No newline at end of file diff --git a/paddleformers/nn/moe/utils.py b/paddleformers/nn/moe/utils.py new file mode 100644 index 00000000000..3a4efa9179f --- /dev/null +++ b/paddleformers/nn/moe/utils.py @@ -0,0 +1,269 @@ +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import Tensor, _C_ops, framework, nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.communication import stream +from paddle.distributed.communication.group import Group +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.nn.functional import moe_combine, moe_gate_dispatch +from paddleformers.utils.log import logger +from paddleformers.transformers.ernie4_5.sequence_parallel_utils import ScatterOp + +from paddleformers.transformers.ernie4_5.distributed.common_dist_utils import ( + AllGatherGroupOp, + ReduceScatterGroupOp, + all_gather_group, + get_async_loader, + hack_offload_wait, + reduce_scatter_group, +) + + +def manual_backward(f: Callable, is_first_fwd: bool, *args: List[Any]): + """ + Perform manual backward pass with gradient tracing control. 
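# The hard-gate branch of fused_gate_logits_process above activates the router only on
# the (token, expert) entries that belong to one modality: text tokens
# (token_type_ids == 0) get probabilities over the text experts, everything else stays
# zero. A compact sketch of the mask / gather_nd / scatter_nd_add pattern, with softmax
# standing in for `gate.act` and toy sizes:
import paddle
import paddle.nn.functional as F

token_type_ids = paddle.to_tensor([0, 1, 0, 1])        # two text and two mm tokens
experts_type_ids = paddle.to_tensor([0, 0, 1, 1])      # two text and two mm experts
gate_logits = paddle.randn([4, 4])

lm_mask = (token_type_ids == 0).unsqueeze(1) & (experts_type_ids == 0).unsqueeze(0)
idx = lm_mask.nonzero()                                # [num_text_entries, 2]
lm_logits = gate_logits.gather_nd(idx).reshape([2, -1])
prob = paddle.zeros_like(gate_logits)
prob = paddle.scatter_nd_add(prob, idx, F.softmax(lm_logits, axis=-1).flatten())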
+ + Args: + f: Function to execute + is_first_fwd: Whether this is the first forward pass + args: Arguments for the function + + Returns: + tuple: (backward function, function outputs) + """ + tracer = framework._dygraph_tracer() + orig = tracer._has_grad + if not is_first_fwd: + tracer._has_grad = True # turn on grad trace so we can manual backward + + detached_args = detach_and_requires_grad_(*args) + detached_args_clone = [ + FakeClone.apply(a) if a is not None else None for a in detached_args + ] + out = f(*detached_args_clone) + if isinstance(out, list): + out = tuple(out) + elif not isinstance(out, tuple): + out = (out,) + + if is_first_fwd: + tracer._has_grad = orig + return None, out + + out_cached = [ + FakeClone.apply(o) for o in out if o is not None + ] # do not cache stop_gradient output + + for o in out_cached: + o._clear_dataptr() # free mem + tracer._has_grad = orig + + def bwd_f(*grad): + nonlocal out_cached, detached_args, f + grad = list(grad) + grad = [g for g in grad if g is not None] + assert grad and out_cached, (len(grad), len(out_cached)) + # out 中的 stop_graident 参数,也会收到 gradient,在这里过滤掉 + grad, out_cached = zip( + *[(g, o) for g, o in zip(grad, out_cached) if not o.stop_gradient] + ) + + assert len(grad) == len(out_cached), (len(grad), len(out_cached), f) + # out, grad = zip(*[(o, g) for o, g in zip(out, grad) if g is not None]) + paddle.autograd.backward(out_cached, grad) + return tuple([t.grad for t in detached_args if t is not None]) + + return bwd_f, out + +def combine_expert_output(expert_output, combine_weights, scatter_index): + """ + Combine expert outputs using combination weights. + + Args: + expert_output: Expert outputs [num_experts, capacity, dim] + combine_weights: Combination weights + scatter_index: Scatter indices + + Returns: + Tensor: Combined output [seqlen, dim] + """ + expert_output = expert_output.reshape( + [-1, expert_output.shape[-1]] + ) # [e*1,c,m] + combined_output = combining(expert_output, combine_weights, scatter_index) + + # if self.output_postprocess is not None: + # combined_output = self.output_postprocess(combined_output) + + return combined_output + +class GateCombine(PyLayer): + """ + Custom PyLayer for gate combination operations with backward pass. + """ + + @staticmethod + def forward(ctx, x, combine_weights, scatter_index): + """ + Forward pass for gate combination. + + Args: + x: Input tensor + combine_weights: Combination weights + scatter_index: Scatter indices + + Returns: + Tensor: Combined output + """ + ctx.x = x + ctx.combine_weights = combine_weights + ctx.scatter_index = scatter_index + ret = moe_combine(x, combine_weights, scatter_index) + return ret + + @staticmethod + def backward(ctx, grad_y, *_): + """ + Backward pass for gate combination. + + Args: + grad_y: Gradient of output [seqlen, hidden_size] + + Returns: + tuple: (grad_x, grad_combine_weight, None) + """ + grad_x, grad_combine_weight_helper = _C_ops.moe_combine_grad( + ctx.x, ctx.combine_weights, ctx.scatter_index, grad_y + ) + # grad_combine_weight_helper is the same shape with grad x [seqlen * K, dim] + # reduce the hidden shape + # TODO: implement reduce in cuda ops + grad_combine_weight = grad_combine_weight_helper.sum(-1) + return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None + +def combining(x, combine_weights, scatter_index, hard_gate=False): + """ + Fused version of combining operation. 
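# A stripped-down version of the manual_backward pattern above (hypothetical helper
# name; the FakeClone / stop_gradient bookkeeping of the real helper is omitted): run
# `fn` on detached copies of its inputs, keep the outputs, and hand back a closure that
# replays the backward for that sub-graph later -- which is what lets the PyLayers
# overlap communication with the sub-network's forward and still return input gradients.
import paddle

def manual_backward_sketch(fn, *args):
    detached = []
    for a in args:
        d = a.detach()
        d.stop_gradient = False
        detached.append(d)
    outs = fn(*detached)
    outs = tuple(outs) if isinstance(outs, (tuple, list)) else (outs,)

    def backward_fn(*out_grads):
        paddle.autograd.backward(list(outs), list(out_grads))
        return tuple(d.grad for d in detached)

    return backward_fn, outs

x = paddle.randn([4, 8])
bwd, (y,) = manual_backward_sketch(lambda t: (t * 2.0).sum(), x)
(x_grad,) = bwd(paddle.ones_like(y))                   # x_grad == 2 everywhere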
+ + Args: + x: Input tensor [seq, dim] + combine_weights: Combination weights [s, k] + scatter_index: Scatter indices [k, s] + hard_gate: Whether to use hard gating + + Returns: + Tensor: Combined output [s, dim] + """ + if hard_gate: + x_gatherd = F.embedding(scatter_index, x) # [s,k,dim] + return x_gatherd.squeeze(-2) + if paddle.device.is_compiled_with_custom_device("npu"): + from ernie.fusion_ops.npu_fusion_ops import npu_combining + + ret = npu_combining(x, combine_weights, scatter_index) + else: + ret = GateCombine.apply(x, combine_weights, scatter_index) + ret.stop_gradient = False + return ret + +def _calc_router_loss( + self, + dispatch_mask, + gate_logits, + gate_prob, + num_experts, + use_group, + layer_idx, + token_type=None, + tokens_type_mask=None, + dispatch_tokens_mask=None, + prefix="", + gate: nn.Layer, +): + """ + Calculate router loss including auxiliary loss, z-loss and orthogonal loss. + + Args: + dispatch_mask: Dispatch mask + gate_logits: Gate logits + gate_prob: Gate probabilities + num_experts: Number of experts + use_group: Whether to use expert groups + layer_idx: Layer index + token_type: Token type + tokens_type_mask: Token type mask + dispatch_tokens_mask: Dispatch tokens mask + prefix: Prefix for logging + + Returns: + Tensor: Total router loss + """ + router_loss, l_aux, orthogonal_loss, zloss = 0.0, None, None, None + if gate.config.moe_aux_loss_lambda: + l_aux = gate._cal_aux_loss( + gate_prob, + dispatch_mask, + num_experts, + use_group, + tokens_type_mask, + dispatch_tokens_mask, + ) + router_loss += gate.moe_aux_loss_lambda[token_type or 0] * l_aux + else: + zero = paddle.to_tensor(0, dtype=paddle.float32) + router_loss += ( + zero * gate_prob[0, 0] + ) # must use gate prob to avoid zero pointer + if gate.config.moe_orthogonal_loss_lambda: + orthogonal_loss = gate._cal_orthogonal_loss(token_type, use_group) + router_loss += ( + gate.moe_orthogonal_loss_lambda[token_type or 0] * orthogonal_loss + ) + if gate.config.moe_z_loss_lambda: + zloss = gate._cal_z_loss(gate_logits, tokens_type_mask) + router_loss += gate.moe_z_loss_lambda[token_type or 0] * zloss + return router_loss + +class ReshardCombineWeight(PyLayer): + """ + Perform weights transform. + """ + + @staticmethod + def forward(ctx, input, group=None): + """Converts expert-partitioned weights to sequence-partitioned format. + + Args: + ctx: PyLayer context object + input (Tensor): Expert-wise partitioned weights [Seq, k] where: + - Non-local experts are zeroed out + - Seq: sequence dimension (may be sharded) + - k: expert capacity + group (ProcessGroup): Model parallel group (default:) + + Returns: + Tensor: Sequence-wise partitioned weights [Seq/n, k] via reduce-scatter + """ + + ctx.mask = input == 0.0 + ctx.group = group + return reduce_scatter_group(input, group=group) + + @staticmethod + def backward(ctx, grad): + """Reconstructs expert-partitioned gradients from sequence-wise gradients. + + Args: + grad (Tensor): Sequence-wise partitioned gradients [Seq/n, k] + + Returns: + Tensor: Expert-wise partitioned gradients [Seq, k] with zeros for + non-local experts + """ + gathered = all_gather_group(grad, group=ctx.group) + return gathered.masked_fill( + ctx.mask, + 0.0, + ) \ No newline at end of file
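# A naive reference for the fused moe_combine op wrapped by combining/GateCombine
# above: each output row is the weighted sum of the k expert-output rows the token was
# scattered to, i.e. out[s] = sum_k combine_weights[s, k] * x[scatter_index[s, k]].
# Toy sizes; the fused kernel additionally supplies the custom backward used above.
import paddle

s, k, d, rows = 4, 2, 8, 16
x = paddle.randn([rows, d])                            # flattened expert outputs
scatter_index = paddle.randint(0, rows, [s, k])        # where each token's k slots live
combine_weights = paddle.rand([s, k])

gathered = paddle.index_select(x, scatter_index.flatten(), axis=0).reshape([s, k, d])
combined = (gathered * combine_weights.unsqueeze(-1)).sum(axis=1)    # [s, d]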