diff --git a/paddleformers/nn/moe/abstract.py b/paddleformers/nn/moe/abstract.py new file mode 100644 index 00000000000..fe3c164996c --- /dev/null +++ b/paddleformers/nn/moe/abstract.py @@ -0,0 +1,20 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + + +class MOELayerBase(nn.Layer): + pass diff --git a/paddleformers/nn/moe/all_gather.py b/paddleformers/nn/moe/all_gather.py new file mode 100644 index 00000000000..f71fd6b6372 --- /dev/null +++ b/paddleformers/nn/moe/all_gather.py @@ -0,0 +1,694 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Optional + +import paddle +import paddle.distributed as dist +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.communication.group import _get_global_group +from paddle.distributed.fleet.utils import recompute + +from .utils import manual_backward + + +def allgather_async(input, group=None): + """Perform asynchronous All-Gather operation for model parallelism. + + Args: + input (Tensor): Local tensor to gather (shape: [N, ...]) + group (ProcessGroup): Model parallel group (default: auto-detected) + + Returns: + tuple: (output_tensor, communication_task) + output_tensor: Pre-allocated buffer with shape [N*K, ...] (K=group_size) + communication_task: Paddle communication task handle for synchronization + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone(), None + output_shape = input.shape + output_shape[0] = output_shape[0] * parallelism + output = paddle.empty(shape=output_shape, dtype=input.dtype) + task = dist.stream.all_gather(output, input, group=group, use_calc_stream=False, sync_op=False) + return output, task + + +def reduce_scatter_async(input, group=None): + """Perform asynchronous reduce-scatter operation for distributed training. + + Args: + input (Tensor): Local tensor to reduce (shape: [N*K, ...], N=group_size) + group (ProcessGroup): Communication group (default: model parallel group) + + Returns: + tuple: (output_tensor, communication_task) + output_tensor: Scattered tensor portion with shape [K, ...] 
+ communication_task: Handle for synchronizing the async operation + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone(), None + output_shape = input.shape + assert ( + input.shape[0] % parallelism == 0 + ), f"Input sequence length {input.shape[0]} can't be divided exactly by sequence parallelism {parallelism}" + output_shape[0] = output_shape[0] // parallelism + output = paddle.empty(shape=output_shape, dtype=input.dtype) + task = dist.stream.reduce_scatter( + output, + input, + op=dist.ReduceOp.SUM, + group=group, + use_calc_stream=False, + sync_op=False, + ) + return output, task + + +class AllGatherAsync(PyLayer): + """ + Perform async allgather. + """ + + @staticmethod + def forward(ctx, input, *fn_args, group=None, fn=None, is_first_fwd=False): + """Forward pass with integrated communication-computation overlap. + + Args: + ctx: PyLayer context object + input (Tensor): Sharded input tensor [s/n, b, h] + *fn_args: Arguments for custom forward function + group: Model parallel process group + fn: Custom forward function to execute after communication + is_first_fwd: Flag indicating first forward pass in sequence + + Returns: + tuple: (gathered_tensor, ...custom_forward_outputs) + """ + ctx.group = group + if dist.get_world_size(group) <= 1: + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + return (input,) + fn_out + out, task = allgather_async(input, group=group) + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + task and task.wait() + return (out,) + fn_out + + @staticmethod + def backward(ctx, grad, *fn_out_grads): + """Backward pass with gradient synchronization. + + Args: + ctx: PyLayer context with stored communication group + grad (Tensor): Full gradient tensor [s, b, h] + *fn_out_grads: Gradients from custom forward outputs + + Returns: + tuple: (scattered_grad, ...custom_arg_grads) + """ + if dist.get_world_size(ctx.group) <= 1: + fn_args_grads = ctx.bwf(*fn_out_grads) + return (grad,) + fn_args_grads + + grad, task = reduce_scatter_async(grad, group=ctx.group) + fn_args_grads = ctx.bwf(*fn_out_grads) + task and task.wait() + return (grad,) + fn_args_grads + + +class AlltoAllSmart(paddle.autograd.PyLayer): + """ + Perform dispatch inputs alltoall. + """ + + @staticmethod + def forward( + ctx, + *inputs, + router_loss_fn: Optional[Callable], + forward_func_dict: Optional[Dict[int, Callable]], + local_expert_id=None, + send_rank_global=None, + recv_rank_global=None, + num_local_experts=None, + capacity=None, + use_padding=True, + expert_num_global=None, + is_first_fwd=None, + group=None, + recv_size=None, + send_counts=None, + recv_counts=None, + send_counts_num=None, + recv_counts_num=None, + ): + """Implements batched point-to-point communication with expert computation overlap. + + Functional Behavior: + - Performs distributed All-to-All communication with variable message sizes + - Overlaps expert forward computation with communication operations + - Calculates router loss for dynamic expert selection + - Handles padding/compression for irregular tensor shapes + + Key Operations: + 1. Prepare communication buffers based on send/recv counts + 2. Launch asynchronous All-to-All operations + 3. Execute expert forward functions in parallel with communication + 4. 
Calculate routing loss and prepare gradient masks + + Args: + ctx: PyLayer context object + *inputs: Variable-length expert inputs (Tensor[...]) + router_loss_fn: Routing loss calculator function + forward_func_dict: Expert-specific forward functions {expert_id: callable} + local_expert_id: Tensor indicating local expert assignments + send_rank_global: Global ranks for sending data + recv_rank_global: Global ranks for receiving data + num_local_experts: Number of experts per device + capacity: Maximum tokens per expert + use_padding: Enable padding for fixed-size buffers + expert_num_global: Global expert count + is_first_fwd: Flag for activation checkpointing + group: Process group for communication + recv_size: Precomputed receive buffer size + send_counts: Per-expert send counts [num_local_experts, world_size] + recv_counts: Per-expert recv counts [num_local_experts, world_size] + send_counts_num: Aggregated send expert + recv_counts_num: Aggregated recv counts per expert + + Returns: + tuple: (output_tensor, router_loss, gradient_mask) + """ + if group is None: + group = _get_global_group() + router_loss_args = inputs[num_local_experts:] + inputs = inputs[:num_local_experts] + + ctx.group = group + ctx.use_padding = use_padding + ctx.num_local_experts = num_local_experts + ctx.input_shape = [i.shape if i is not None else None for i in inputs] + + this_rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + capacity = len(send_rank_global) // world_size // num_local_experts + ctx.capacity = capacity + assert len(local_expert_id) == len(recv_rank_global), ( + len(local_expert_id), + len(recv_rank_global), + ) + + for i in inputs: + if i is not None: + input_dtype = i.dtype + input_shape = i.shape + break + else: + raise RuntimeError("all inputs are None") + + output = paddle.zeros([recv_size] + input_shape[1:], dtype=input_dtype) + output_ptr = 0 + + tasks = [] + dummy_input = paddle.empty([0] + input_shape[1:], dtype=input_dtype) + ctx.dummy_input = dummy_input + ctx.bw_funcs = {} + + for i_local_expert in range(num_local_experts): + send_count = send_counts[i_local_expert] + recv_count = recv_counts[i_local_expert] + assert len(recv_count) == len(send_count) == (world_size), ( + len(recv_count), + len(send_count), + ) + + if send_counts_num[i_local_expert] > 0: + input_local_expert = inputs[i_local_expert].slice((0,), 0, send_counts_num[i_local_expert]) + if forward_func_dict is not None: + input_local_expert.stop_gradient = False + bwf, (input_local_expert,) = manual_backward( + forward_func_dict[i_local_expert], + is_first_fwd, + input_local_expert, + ) + ctx.bw_funcs[i_local_expert] = bwf + + if input_local_expert is None: + input_local_expert = dummy_input + input_local_expert.stop_gradient = True + else: + input_local_expert = dummy_input + if recv_counts_num[i_local_expert] > 0: + # When FLAGS_use_stride_kernel=0, tensor.slice(...) returns a + # new tensor instead of a view, causing in-place assignment to fail. + # tensor._slice ensures it always returns a view. 
+ # See: + # https://github.com/PaddlePaddle/Paddle/blob/release/3.1/paddle/phi/core/dense_tensor_impl.cc#L299 + output_local_expert = output._slice(output_ptr, (output_ptr + recv_counts_num[i_local_expert])) + else: + output_local_expert = dummy_input + + output_ptr += recv_counts_num[i_local_expert] + + if group.nranks <= 1: + output_local_expert[:] = input_local_expert[:] + else: + tasks.append( + dist.stream.alltoall_single( + output_local_expert, + input_local_expert, + recv_count, + send_count, + group=group, + sync_op=False, + use_calc_stream=False, + ) + ) + ctx.router_loss_bwfn, (router_loss,) = manual_backward(router_loss_fn, is_first_fwd, *router_loss_args) + with paddle.no_grad(): + recv_mask = (recv_rank_global == this_rank).astype(send_rank_global.dtype) + if ctx.use_padding: + recv_mask_alltoall_out = ( + recv_mask.reshape([-1, num_local_experts, capacity]).transpose([1, 0, 2]).reshape([-1]) + ) + distributed_input_to_alltoall_out = paddle.maximum( + recv_mask_alltoall_out.cumsum() - 1, + paddle.zeros([1], dtype=recv_mask_alltoall_out.dtype), + ) + distributed_input_to_alltoall_out = ( + distributed_input_to_alltoall_out.view([num_local_experts, -1, capacity]) + .transpose([1, 0, 2]) + .reshape([-1]) + ) + else: + recv_mask_alltoall_out = recv_mask.split(expert_num_global) # h->d copy break overlap + recv_mask_alltoall_out = [ + recv_mask_alltoall_out[(iexpert % world_size) * num_local_experts + (iexpert // world_size)] + for iexpert in range(world_size * num_local_experts) + ] + alltoall_shape = [i.shape[0] for i in recv_mask_alltoall_out] + + recv_mask_alltoall_out = paddle.concat(recv_mask_alltoall_out, 0) + distributed_input_to_alltoall_out = paddle.maximum( + recv_mask_alltoall_out.cumsum() - 1, + paddle.zeros([1], dtype=recv_mask_alltoall_out.dtype), + ) + distributed_input_to_alltoall_out = distributed_input_to_alltoall_out.split(alltoall_shape) + + distributed_input_to_alltoall_out = paddle.concat( + [ + distributed_input_to_alltoall_out[ + (iexpert % num_local_experts) * world_size + (iexpert // num_local_experts) + ] + for iexpert in range(world_size * num_local_experts) + ], + 0, + ) + + distributed_input_to_alltoall_out.stop_gradient = True + for t in tasks: + t and t.wait() + ctx.send_counts = send_counts + ctx.recv_counts = recv_counts + return output, router_loss, distributed_input_to_alltoall_out + + @staticmethod + def backward( + ctx, + out_grad, + d_routerloss, + _, # scatter-idx no grad + ): + """Performs distributed gradient propagation for expert-parallel models. + + Functional Behavior: + - Distributes output gradients via reverse All-to-All communication + - Computes expert-specific gradients using stored backward functions + - Aggregates routing loss gradients + + Key Operations: + 1. Prepare gradient buffers based on forward pass metadata + 2. Execute reverse All-to-All communication + 3. Apply expert-specific backward computations + 4. 
Combine gradients from all sources + + Args: + ctx: Context object storing forward pass information + out_grad (Tensor): Gradient from downstream layers + d_routerloss (Tensor): Routing loss gradient + _: Ignored placeholder + + Returns: + tuple: Combined gradients (expert gradients + router loss gradients) + """ + + grads = [paddle.zeros(s, dtype=out_grad.dtype) if s is not None else None for s in ctx.input_shape] + assert len(grads) == ctx.num_local_experts + out_ptr = 0 + tasks = [] + tmp_g = [] + send_counts_num = ctx.send_counts.sum(-1) + recv_counts_num = ctx.recv_counts.sum(-1) + out_grad = out_grad.contiguous() + for i_local_expert in range(ctx.num_local_experts): + send_count = ctx.send_counts[i_local_expert] + recv_count = ctx.recv_counts[i_local_expert] + if recv_counts_num[i_local_expert] > 0: + out_g = out_grad.slice((0,), out_ptr, out_ptr + recv_counts_num[i_local_expert]) + else: + out_g = ctx.dummy_input # paddle.empty([0,]+out_grad.shape[1:], dtype=out_grad.dtype) + if send_counts_num[i_local_expert] > 0: + # When FLAGS_use_stride_kernel=0, tensor.slice(...) returns a + # new tensor instead of a view, causing in-place assignment to fail. + # tensor._slice ensures it always returns a view. + # See: + # https://github.com/PaddlePaddle/Paddle/blob/release/3.1/paddle/phi/core/dense_tensor_impl.cc#L299 + g = grads[i_local_expert]._slice(0, send_counts_num[i_local_expert]) + else: + g = ctx.dummy_input + tmp_g.append(g) + out_ptr += recv_counts_num[i_local_expert] + if ctx.group.nranks <= 1: + g[:] = out_g[:] + else: + task = dist.stream.alltoall_single( + g, + out_g, + send_count, + recv_count, + group=ctx.group, + sync_op=False, + use_calc_stream=False, + ) + tasks.append(task) + router_fn_args_grad = ctx.router_loss_bwfn(d_routerloss) + + for i_local_expert, t in enumerate(tasks): + t and t.wait() + send_cnt = send_counts_num[i_local_expert] + if send_cnt > 0 and ctx.bw_funcs: + (g,) = ctx.bw_funcs[i_local_expert](tmp_g[i_local_expert]) + grads[i_local_expert][:send_cnt] = g + + grads = [g for g in grads if g is not None] + return tuple(grads) + tuple(router_fn_args_grad) + + +class AlltoAllSmartXPU(paddle.autograd.PyLayer): + """ + Perform dispatch inputs alltoall. 
(XPU VERSION) + """ + + @staticmethod + def forward( + ctx, + *inputs, + router_loss_fn: Optional[Callable], + forward_func_dict: Optional[Dict[int, Callable]], + local_expert_id=None, + send_rank_global=None, + recv_rank_global=None, + num_local_experts=None, + capacity=None, + use_padding=True, + expert_num_global=None, + is_first_fwd=None, + group=None, + recv_size=None, + send_counts=None, + recv_counts=None, + send_counts_num=None, + recv_counts_num=None, + ): + if group is None: + group = _get_global_group() + router_loss_args = inputs[num_local_experts:] + inputs = inputs[:num_local_experts] + + ctx.group = group + ctx.use_padding = use_padding + ctx.num_local_experts = num_local_experts + ctx.input_shape = [i.shape if i is not None else None for i in inputs] + ctx.send_counts = send_counts + ctx.recv_counts = recv_counts + ctx.send_counts_num = send_counts_num + ctx.recv_counts_num = recv_counts_num + + world_size = dist.get_world_size(group) + this_rank = dist.get_rank(group) + if use_padding and capacity is None: + capacity = len(send_rank_global) // world_size // num_local_experts + + for i in inputs: + if i is not None: + input_dtype = i.dtype + input_shape = i.shape + break + else: + first_expert = forward_func_dict[0] + input_dtype = first_expert.up_gate_proj.weight.dtype + hidden_size = first_expert.up_gate_proj.weight.shape[0] + input_shape = [0, hidden_size] + + dummy_input = paddle.empty([0] + input_shape[1:], dtype=input_dtype) + ctx.dummy_input = dummy_input + ctx.bw_funcs = {} + + processed_inputs = [] + no_tokens_expert_outputs = [] + + for i_local_expert in range(num_local_experts): + if send_counts_num[i_local_expert] > 0: + input_local_expert = inputs[i_local_expert].slice((0,), 0, send_counts_num[i_local_expert]) + if forward_func_dict is not None: + input_local_expert.stop_gradient = False + bwf, (processed_input,) = manual_backward( + forward_func_dict[i_local_expert], + is_first_fwd, + input_local_expert, + ) + ctx.bw_funcs[i_local_expert] = bwf + processed_input.stop_gradient = True + else: + processed_input = input_local_expert + processed_inputs.append(processed_input) + elif forward_func_dict is not None: + expert_func = forward_func_dict[i_local_expert] + fake_chunk = paddle.zeros( + [1, expert_func.up_gate_proj.weight.shape[0]], + dtype=expert_func.up_gate_proj.weight.dtype, + ) + if expert_func.training: + fake_chunk.stop_gradient = False + + _, (expert_out,) = manual_backward(expert_func, is_first_fwd, fake_chunk) + + no_tokens_expert_outputs.append(expert_out * 0.0) + + all_processed_inputs = paddle.concat(processed_inputs, axis=0) if processed_inputs else dummy_input + + if no_tokens_expert_outputs: + if all_processed_inputs.shape[0] > 0: + all_processed_inputs[0] = all_processed_inputs[0] + sum(no_tokens_expert_outputs) + else: + router_loss_args = list(router_loss_args) + router_loss_args[0] = router_loss_args[0] + sum(no_tokens_expert_outputs).mean() * 0.0 + + in_tensors_by_rank = [[] for _ in range(world_size)] + processed_input_ptr = 0 + for i_local_expert in range(num_local_experts): + num_tokens = send_counts_num[i_local_expert] + if num_tokens > 0: + expert_input = all_processed_inputs.slice([0], processed_input_ptr, processed_input_ptr + num_tokens) + processed_input_ptr += num_tokens + splits = expert_input.split(send_counts[i_local_expert].tolist(), axis=0) + for j_rank in range(world_size): + in_tensors_by_rank[j_rank].append(splits[j_rank]) + + in_tensor_list = [paddle.concat(tensors, 0) if tensors else dummy_input for tensors in 
in_tensors_by_rank] + + all_to_all_input = paddle.concat(in_tensor_list, 0) + send_counts_for_api = [t.shape[0] for t in in_tensor_list] + + recv_counts_tensor = paddle.to_tensor(recv_counts) + recv_counts_for_api = [int(recv_counts_tensor[:, j_rank].sum()) for j_rank in range(world_size)] + temp_output = paddle.empty([recv_size.item()] + input_shape[1:], dtype=input_dtype) + + if group.nranks <= 1: + task = None + if all_to_all_input.shape[0] > 0: + temp_output[:] = all_to_all_input[:] + else: + task = dist.stream.alltoall_single( + temp_output, + all_to_all_input, + recv_counts_for_api, + send_counts_for_api, + group=group, + sync_op=False, + use_calc_stream=False, + ) + + ctx.router_loss_bwfn, (router_loss,) = manual_backward(router_loss_fn, is_first_fwd, *router_loss_args) + with paddle.no_grad(): + recv_mask = (recv_rank_global == this_rank).astype(send_rank_global.dtype) + if ctx.use_padding: + recv_mask_alltoall_out = ( + recv_mask.reshape([-1, num_local_experts, capacity]).transpose([1, 0, 2]).reshape([-1]) + ) + distributed_input_to_alltoall_out = paddle.maximum( + recv_mask_alltoall_out.cumsum() - 1, + paddle.zeros([1], dtype=recv_mask_alltoall_out.dtype), + ) + distributed_input_to_alltoall_out = ( + distributed_input_to_alltoall_out.view([num_local_experts, -1, capacity]) + .transpose([1, 0, 2]) + .reshape([-1]) + ) + else: + recv_mask_alltoall_out = recv_mask.split(expert_num_global) + recv_mask_alltoall_out = [ + recv_mask_alltoall_out[(iexpert % world_size) * num_local_experts + (iexpert // world_size)] + for iexpert in range(world_size * num_local_experts) + ] + alltoall_shape = [i.shape[0] for i in recv_mask_alltoall_out] + recv_mask_alltoall_out = paddle.concat(recv_mask_alltoall_out, 0) + distributed_input_to_alltoall_out = paddle.maximum( + recv_mask_alltoall_out.cumsum() - 1, + paddle.zeros([1], dtype=recv_mask_alltoall_out.dtype), + ) + distributed_input_to_alltoall_out = distributed_input_to_alltoall_out.split(alltoall_shape) + distributed_input_to_alltoall_out = paddle.concat( + [ + distributed_input_to_alltoall_out[ + (iexpert % num_local_experts) * world_size + (iexpert // num_local_experts) + ] + for iexpert in range(world_size * num_local_experts) + ], + 0, + ) + + distributed_input_to_alltoall_out.stop_gradient = True + + if task is not None: + task.wait() + + temp_output_splits_by_src_rank = temp_output.split(recv_counts_for_api, 0) + chunks_by_expert = [[] for _ in range(num_local_experts)] + for j_rank in range(world_size): + data_from_j = temp_output_splits_by_src_rank[j_rank] + expert_chunks_from_j = data_from_j.split(recv_counts[:, j_rank].tolist(), 0) + for i_expert in range(num_local_experts): + chunks_by_expert[i_expert].append(expert_chunks_from_j[i_expert]) + + output_chunks = [] + for i_expert in range(num_local_experts): + if recv_counts_num[i_expert] > 0: + output_chunks.append(paddle.concat(chunks_by_expert[i_expert], 0)) + output = paddle.concat(output_chunks, 0) if output_chunks else dummy_input + + return output, router_loss, distributed_input_to_alltoall_out + + @staticmethod + def backward( + ctx, + out_grad, + d_routerloss, + _, # scatter-idx no grad + ): + world_size = dist.get_world_size(ctx.group) + num_local_experts = ctx.num_local_experts + dummy_input = ctx.dummy_input + out_grad = out_grad.contiguous() + + send_counts_bw = ctx.recv_counts + send_counts_num_bw = ctx.recv_counts_num + in_tensors_by_rank_bw = [[] for _ in range(world_size)] + grad_ptr = 0 + for i_expert in range(num_local_experts): + num_tokens = 
send_counts_num_bw[i_expert] + if num_tokens > 0: + expert_grad = out_grad.slice([0], grad_ptr, grad_ptr + num_tokens) + grad_ptr += num_tokens + splits = expert_grad.split(send_counts_bw[i_expert].tolist(), 0) + for j_rank in range(world_size): + in_tensors_by_rank_bw[j_rank].append(splits[j_rank]) + in_tensor_list_bw = [ + paddle.concat(tensors, 0) if tensors else dummy_input for tensors in in_tensors_by_rank_bw + ] + + all_to_all_grad_input = paddle.concat(in_tensor_list_bw, 0) + send_counts_bw_for_api = [t.shape[0] for t in in_tensor_list_bw] + + recv_counts_bw = ctx.send_counts + recv_counts_tensor_bw = paddle.to_tensor(recv_counts_bw) + recv_counts_bw_for_api = [int(recv_counts_tensor_bw[:, j_rank].sum()) for j_rank in range(world_size)] + total_output_grad_size = int(ctx.send_counts_num.sum()) + temp_grad_output = paddle.empty([total_output_grad_size] + list(out_grad.shape[1:]), dtype=out_grad.dtype) + + if ctx.group.nranks <= 1: + task = None + if all_to_all_grad_input.shape[0] > 0: + temp_grad_output[:] = all_to_all_grad_input[:] + else: + task = dist.stream.alltoall_single( + temp_grad_output, + all_to_all_grad_input, + recv_counts_bw_for_api, + send_counts_bw_for_api, + group=ctx.group, + sync_op=False, + use_calc_stream=False, + ) + + router_fn_args_grad = ctx.router_loss_bwfn(d_routerloss) + + if task is not None: + task.wait() + + temp_grad_output_splits = temp_grad_output.split(recv_counts_bw_for_api, 0) + grad_chunks_by_expert = [[] for _ in range(num_local_experts)] + for j_rank in range(world_size): + data_from_j = temp_grad_output_splits[j_rank] + expert_chunks_from_j = data_from_j.split(recv_counts_bw[:, j_rank].tolist(), 0) + for i_expert in range(num_local_experts): + grad_chunks_by_expert[i_expert].append(expert_chunks_from_j[i_expert]) + + grads = [paddle.zeros(s, dtype=out_grad.dtype) if s is not None else None for s in ctx.input_shape] + for i_expert in range(num_local_experts): + num_tokens = ctx.send_counts_num[i_expert] + if num_tokens > 0: + reconstructed_grad = paddle.concat(grad_chunks_by_expert[i_expert], 0) + if i_expert in ctx.bw_funcs: + (final_grad,) = ctx.bw_funcs[i_expert](reconstructed_grad) + else: + final_grad = reconstructed_grad + if grads[i_expert] is not None: + grads[i_expert][:num_tokens] = final_grad + + grads = [g for g in grads if g is not None] + return tuple(grads) + tuple(router_fn_args_grad) + + +# Conditionally select the AlltoAllSmart implementation +if paddle.is_compiled_with_xpu(): + AlltoAllSmart = AlltoAllSmartXPU diff --git a/paddleformers/nn/moe/all_to_all.py b/paddleformers/nn/moe/all_to_all.py new file mode 100644 index 00000000000..4cb3e9d8faf --- /dev/null +++ b/paddleformers/nn/moe/all_to_all.py @@ -0,0 +1,134 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
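+
+# Overview: this module defines two PyLayers, ``AlltoAll`` (a blocking
+# All-to-All with a matching backward pass) and ``AlltoAllAsync`` (an
+# All-to-All that overlaps a user-supplied ``fn`` with the communication).
+# A minimal, illustrative sketch; ``mp_group`` and ``shared_fn`` are
+# placeholder names and a hybrid-parallel fleet setup is assumed:
+#
+#   hcg = fleet.get_hybrid_communicate_group()
+#   mp_group = hcg.get_model_parallel_group()
+#   y = AlltoAll.apply(x, mp_group)              # blocking exchange
+#   y, *fn_outs = AlltoAllAsync.apply(
+#       x, h, group=mp_group, fn=shared_fn       # shared_fn(h) runs while
+#   )                                            # the all-to-all is in flight
+#
+# With a single-rank group both layers skip the communication and pass the
+# input through.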
+ +import paddle +import paddle.distributed as dist +from paddle import Tensor +from paddle.autograd import PyLayer +from paddle.distributed.communication import stream + +from .utils import manual_backward + + +class AlltoAll(PyLayer): + """ + Custom PyLayer for All-to-All communication with backward pass. + """ + + @staticmethod + def forward(ctx, x, group, sync_op=True): + """ + Perform All-to-All communication in the group. + + Args: + x: Input tensor + group: Communication group + sync_op: Whether to perform synchronous operation + + Returns: + Tensor: Output tensor + """ + ctx.group = group + if dist.get_world_size(group) <= 1: + return x + output = paddle.empty_like(x) + output.stop_gradient = False + task = stream.alltoall_single(output, x, None, None, group, sync_op=sync_op, use_calc_stream=sync_op) + if not sync_op: + return output, task + else: + return output + + @staticmethod + def backward(ctx, *dx): + """ + Backward pass for All-to-All communication. + + Args: + dx: Gradient tensor + + Returns: + Tensor: Gradient after backward All-to-All + """ + return AlltoAll.apply(*dx, group=ctx.group) + + +class AlltoAllAsync(PyLayer): + """ + Custom PyLayer for asynchronous All-to-All communication. + """ + + @staticmethod + def forward(ctx, x, *fn_args, group=None, fn=None, is_first_fwd=False): + """ + Asynchronous All-to-All communication with function execution. + + Args: + x: Input tensor + fn_args: Arguments for the function + group: Communication group + fn: Function to execute + is_first_fwd: Whether this is the first forward pass + + Returns: + tuple: (output tensor, function outputs) + """ + assert fn is not None, "use AlltoAll no async" + ctx.group = group + if dist.get_world_size(group) <= 1: + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + return (x,) + fn_out + x_out = paddle.empty_like(x) + x_out.stop_gradient = False + task = stream.alltoall_single( + x_out, + x, + None, + None, + group, + sync_op=False, + ) + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + task.wait() + return (x_out,) + fn_out + + @staticmethod + def backward(ctx, dx_out, *fn_out_grads): + """ + Backward pass for asynchronous All-to-All. + + Args: + dx_out: Gradient of output + fn_out_grads: Gradients of function outputs + + Returns: + tuple: (gradient tensor, function argument gradients) + """ + if dist.get_world_size(ctx.group) <= 1: + fn_args_grads = ctx.bwf(*fn_out_grads) + return (dx_out,) + fn_args_grads + + dx = paddle.empty_like(dx_out) + dx.stop_gradient = False + task = stream.alltoall_single( + dx, + dx_out, + None, + None, + ctx.group, + sync_op=False, + ) + fn_args_grads = ctx.bwf(*fn_out_grads) + task.wait() + return (dx,) + fn_args_grads diff --git a/paddleformers/nn/moe/fused_a2a.py b/paddleformers/nn/moe/fused_a2a.py new file mode 100644 index 00000000000..7b5fa09c9e0 --- /dev/null +++ b/paddleformers/nn/moe/fused_a2a.py @@ -0,0 +1,216 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 DeepSeek +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import paddle.distributed.communication.deep_ep as deep_ep + + HAVE_DEEP_EP = True +except ImportError: + HAVE_DEEP_EP = False + +import paddle +from paddle.autograd import PyLayer +from paddle.distributed.communication.group import Group + +_buffer = None + + +def get_hidden_bytes(x: paddle.Tensor) -> int: + """Calculate the number of hidden bytes for a tensor. + + Args: + x (paddle.Tensor): Input tensor + + Returns: + int: Number of hidden bytes + """ + return x.shape[1] * max(x.element_size(), 2) + + +def get_buffer(group: Group, hidden_bytes: int): + """Get or create a buffer for all-to-all communication. + + Args: + group (paddle.distributed.ProcessGroup): Process group for communication + hidden_bytes (int): Number of hidden bytes needed + + Returns: + Buffer: Communication buffer + """ + global _buffer + num_nvl_bytes, num_rdma_bytes = 0, 0 + for config in ( + deep_ep.Buffer.get_dispatch_config(group.world_size), + deep_ep.Buffer.get_combine_config(group.world_size), + ): + # Split long line for PEP8 compliance + num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.world_size), num_nvl_bytes) + num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.world_size), num_rdma_bytes) + + # Allocate buffer if not existed or not enough buffer + # NOTES: the adaptive routing configuration of the network **must be off** + if ( + _buffer is None + or _buffer.group != group + or _buffer.num_nvl_bytes < num_nvl_bytes + or _buffer.num_rdma_bytes < num_rdma_bytes + ): + _buffer = deep_ep.Buffer(group, num_nvl_bytes, num_rdma_bytes) + return _buffer + + +class FusedDispatch(PyLayer): + """Fused dispatch operation for MoE routing combining computation and communication.""" + + @staticmethod + def forward(ctx, x, token_indices, token_probs, num_experts, group, previous_event=None): + """Forward pass of fused dispatch.""" + # Calculate layout before actual dispatch + buffer = get_buffer(group, get_hidden_bytes(x)) + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + previous_event, + ) = buffer.get_dispatch_layout( + token_indices, + num_experts, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ) + + # Do MoE dispatch + # NOTES: the CPU will wait for GPU's signal to arrive, + # so this is not compatible with CUDA graph + ( + recv_x, + recv_token_indices, + recv_token_probs, + num_recv_tokens_per_expert_list, + handle, + event, + ) = buffer.dispatch( + x, + topk_idx=token_indices, + topk_weights=token_probs.cast(paddle.float32), + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ) + + ctx.group = group + ctx.handle = handle + ctx.event = event + tokens_per_expert = paddle.to_tensor(num_recv_tokens_per_expert_list) + + states = dict() + states["dispatched_indices"] = recv_token_indices + states["tokens_per_expert"] = tokens_per_expert + states["handle"] = handle + + return recv_x, recv_token_probs, states + + @staticmethod + def backward(ctx, grad_output, grad_token_probs): + """Backward pass of fused dispatch.""" + buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) + handle = ctx.handle + + grad_x, grad_token_probs, event = buffer.combine( + 
grad_output.contiguous(), + handle, + topk_weights=grad_token_probs.cast(paddle.float32), + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ) + return grad_x, None, grad_token_probs + + +class FusedCombine(PyLayer): + """Fused combine operation for MoE output combining computation and communication.""" + + @staticmethod + def forward(ctx, x, group, states, previous_event=None): + """Forward pass of fused combine.""" + handle = states["handle"] + buffer = get_buffer(group, get_hidden_bytes(x)) + combined_x, _, event = buffer.combine( + x, handle=handle, async_finish=False, previous_event=None, allocate_on_comm_stream=False + ) + ctx.handle = handle + ctx.group = group + ctx.previous_event = previous_event + + return combined_x + + @staticmethod + def backward(ctx, grad_output): + """Backward pass of fused combine.""" + buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) + grad_x, _, _, _, _, event = buffer.dispatch( + grad_output.contiguous(), + handle=ctx.handle, + previous_event=ctx.previous_event, + async_finish=False, + allocate_on_comm_stream=False, + ) + return grad_x + + +if HAVE_DEEP_EP: + + def fused_dispatch(x, token_indices, token_probs, num_experts, group: Group, previous_event=None): + """Perform fused dispatch operation if deep_ep is available. + + Args: + x: Input tensor [num_tokens, hidden_size] + token_indices: Token routing indices [num_tokens, topk] + token_probs: Token routing probabilities [num_tokens, topk] + num_experts: Number of experts + group: Process group + previous_event: Previous CUDA event + + Returns: + Result of FusedDispatch + """ + return FusedDispatch.apply(x.contiguous(), token_indices, token_probs, num_experts, group, previous_event) + + def fused_combine(x, group, handle, previous_event=None): + """Perform fused combine operation if deep_ep is available. + + Args: + x: Input tensor + group: Process group + handle: Communication handle + previous_event: Previous CUDA event + + Returns: + Result of FusedCombine + """ + states = dict() + states["handle"] = handle + return FusedCombine.apply(x, group, states, previous_event) + +else: + fused_dispatch = None + fused_combine = None diff --git a/paddleformers/nn/moe/moe_allgather_layer.py b/paddleformers/nn/moe/moe_allgather_layer.py new file mode 100644 index 00000000000..8c21e0998a0 --- /dev/null +++ b/paddleformers/nn/moe/moe_allgather_layer.py @@ -0,0 +1,843 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
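+
+# Rough construction sketch for the layer defined below (hypothetical names;
+# assumes fleet hybrid parallelism and an expert-parallel ``moe_group`` are
+# already initialized, and that ``gate`` / ``experts`` implement the
+# interfaces this file relies on, e.g. ``gate.get_capacity``):
+#
+#   moe = MOEAllGatherLayerV2(
+#       gate=gate,               # routing network
+#       experts=experts,         # nn.LayerList holding the expert FFNs
+#       layer_idx=0,
+#       group=moe_group,         # expert-parallel process group
+#       k=2,                     # top-k experts per token
+#       use_padding=True,        # fixed-capacity dispatch buffers
+#   )
+#   out, combine_weights, router_loss, gate_logits = moe(hidden_states)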
+ + +""" +moe_layer_all_gather +""" + +import inspect +from typing import Callable, Dict, List, Optional, Tuple + +import paddle +import paddle.distributed as dist +from paddle import framework, nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.communication.group import Group, _get_global_group +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.nn.functional import ( + build_src_rank_and_local_expert_id, + expand_modality_expert_id, + moe_gate_dispatch_partial_nosoftmaxtopk, +) +from paddle.incubate.tensor.manipulation import async_offload + +from paddleformers.peft.lora.lora_quantization_layers import QuantizationLoRALinear +from paddleformers.utils.log import logger + +from .all_gather import AllGatherAsync, AlltoAllSmart +from .moe_alltoall_layer import MOEAlltoAllLayer +from .utils import ( + AllGatherGroupOp, + ReduceScatterGroupOp, + all_gather_group, + get_async_loader, + hack_offload_wait, + reduce_scatter_group, +) + + +class ReshardCombineWeight(PyLayer): + """ + Perform weights transform. + """ + + @staticmethod + def forward(ctx, input, group=None): + """Converts expert-partitioned weights to sequence-partitioned format. + + Args: + ctx: PyLayer context object + input (Tensor): Expert-wise partitioned weights [Seq, k] where: + - Non-local experts are zeroed out + - Seq: sequence dimension (may be sharded) + - k: expert capacity + group (ProcessGroup): Model parallel group (default:) + + Returns: + Tensor: Sequence-wise partitioned weights [Seq/n, k] via reduce-scatter + """ + + ctx.mask = input == 0.0 + ctx.group = group + return reduce_scatter_group(input, group=group) + + @staticmethod + def backward(ctx, grad): + """Reconstructs expert-partitioned gradients from sequence-wise gradients. + + Args: + grad (Tensor): Sequence-wise partitioned gradients [Seq/n, k] + + Returns: + Tensor: Expert-wise partitioned gradients [Seq, k] with zeros for + non-local experts + """ + gathered = all_gather_group(grad, group=ctx.group) + return gathered.masked_fill( + ctx.mask, + 0.0, + ) + + +class MOEAllGatherLayerV2(MOEAlltoAllLayer): + """ + MoE Layer with allgather implement. 
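+
+    At a high level the forward pass gates the local tokens, all-gathers them
+    across the expert-parallel group, dispatches tokens to the local experts,
+    runs those experts, and returns the expert outputs either through an
+    all-to-all (``use_expert_out_alltoall=True``) or an all-gather before the
+    weighted combine and the router-loss computation.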
+ """ + + def __init__( + self, + gate: nn.Layer, + experts: List[nn.Layer], + layer_idx, + shared_experts: Optional[List[nn.Layer]] = None, + group: Group = None, + recompute=False, + k=2, + enable_reverse_token_drop=False, + all_to_all_dropout=0, + group_experts=False, + use_expert_out_alltoall=True, # + use_padding=True, + dense_token_type=3, # considerd as dense tokens (no moe) + moe_statics=None, + moe_num_experts=None, + ): + super().__init__( + gate, + experts, + layer_idx, + shared_experts, + group, + recompute, + k, + all_to_all_dropout, + group_experts, + moe_statics, + moe_num_experts, + ) + self.enable_reverse_token_drop = enable_reverse_token_drop + self.use_padding = use_padding + + # 全局 gate gather + self.send_rank = None + self.local_expert_id = None + self.dense_token_type = dense_token_type + self.capacity_tensor = None + self.use_expert_out_alltoall = use_expert_out_alltoall + logger.info( + f"uisng MOEAllGatherLayerV2, use_expert_out_alltoall={use_expert_out_alltoall}, " # false + f"use_padding={use_padding}, enable_reverse_token_drop={self.enable_reverse_token_drop}" # true false + ) + self.zero = paddle.to_tensor(0, dtype=paddle.float32) + + def forward( + self, + input: paddle.Tensor, + token_type_ids=None, + use_dense_expert=False, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements forward pass for Mixture-of-Experts (MoE) layer with distributed communication. + + Core Functionality: + - Processes input through gating network to determine expert assignments + - Performs distributed All-to-All communication for expert computation + - Combines expert outputs and calculates routing loss + + Key Features: + 1. Supports both dense and sparse expert computation modes + 2. Implements fused gating and dispatch for performance optimization + 3. Handles sequence length padding/unpadding for irregular inputs + 4. 
Enables communication-computation overlap through asynchronous operations + + Args: + input (Tensor): Input tensor of shape [seq_len, hidden_dim] + token_type_ids: Optional segmentation markers for heterogeneous inputs + use_dense_expert: Flag to enable dense expert computation bypass + + Returns: + tuple: ( + combined_output: Aggregated expert outputs [seq_len, hidden_dim], + combine_weights: Expert combination coefficients, + router_loss: Calculated router balancing loss + ) + """ + if input.ndim == 3: + orig_shape = input.shape + input = input.reshape([-1, input.shape[-1]]) + else: + orig_shape = None + + assert len(input.shape) == 2, f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" + dispatch_token_type_ids = None + global_dense_expert_mask = None + if token_type_ids is not None: + token_type_ids = token_type_ids[:, :-1].reshape([-1]) + dispatch_token_type_ids = token_type_ids + if self.config.sequence_parallel: + hcg = fleet.get_hybrid_communicate_group() + rank = hcg.get_model_parallel_rank() + interval = token_type_ids.shape[0] // hcg.get_model_parallel_world_size() + token_type_ids = token_type_ids.slice([0], rank * interval, (rank + 1) * interval) + token_type_ids.stop_gradient = True + + if use_dense_expert: + global_dense_expert_mask = dispatch_token_type_ids == self.dense_token_type + + assert self.gate is not None + if hasattr(self, "rng") and self.rng.random() < self.all_to_all_dropout: + orig_shape_2 = input.shape + output = self.forward_experts(input) + output += self.gate.weight.sum() * 0.0 # hack for grad + output = output.reshape(orig_shape or orig_shape_2) # [e*1,c,m] + return output, None, 0 + ( + dispatched_input, + global_hidden_states, + local_combine_weights, + expert_num_global_no_token_drop, + expert_num_global, + expert_num_global_list, + local_scatter_index, + scatter_index_rev, + router_loss, + (gate_logits, gate_prob), + (gate_logits_mm, gate_prob_mm), + expert_num_local, + ) = self.fused_gate_and_dispatch(input, token_type_ids, global_dense_expert_mask) + seqlen_this_mp = input.shape[0] + if len(scatter_index_rev): + recv_rank_local = scatter_index_rev // seqlen_this_mp + else: + recv_rank_local = scatter_index_rev + + if self.use_padding: + if self.send_rank is None: + capacity = self.gate.get_capacity(input.shape[0] * self.config.moe_world_size) + self.send_rank = ( + paddle.arange(self.config.moe_world_size) + .repeat_interleave(capacity * self.num_local_experts) + .astype("int32") # cap + ) + self.local_expert_id = ( + paddle.arange(self.num_local_experts) + .repeat_interleave(capacity) + .tile(self.config.moe_world_size) + .astype(self.send_rank.dtype) + ) + recv_rank, recv_rank_task = allgather_async(recv_rank_local, group=self.config.moe_group) + send_rank = self.send_rank + local_expert_id = self.local_expert_id + + else: + all_expert_num = sum(expert_num_global_list) + # 非常慢 + if self.config.moe_group.nranks > 1: + recv_rank = paddle.empty([all_expert_num], dtype=recv_rank_local.dtype) + # 非常慢 + recv_rank_task = dist.stream.alltoall_single( + recv_rank, + recv_rank_local.tile(self.config.moe_world_size), + [ + sum(expert_num_global_list[i * self.num_local_experts : (i + 1) * self.num_local_experts]) + for i in range(self.config.moe_world_size) + ], # output-size + [len(recv_rank_local)] * self.config.moe_world_size, # input-size + group=self.config.moe_group, + sync_op=False, + use_calc_stream=False, + ) + else: + recv_rank_task = None + recv_rank = recv_rank_local.tile(self.config.moe_world_size) + + send_rank, 
local_expert_id = build_src_rank_and_local_expert_id( + expert_num_global, expert_num_global_list, self.num_local_experts + ) + + if not self.use_expert_out_alltoall: + expert_outs = ( + recompute(self.forward_experts, *dispatched_input) + if self.recompute and self.training + else self.forward_experts(*dispatched_input) + ) + expert_outs = paddle.concat([e for e in expert_outs if e is not None], axis=0) # [e*c,m] + expert_out_to_combine = AllGatherGroupOp.apply(expert_outs, group=self.config.moe_group) # for test + router_loss2 = self.calc_router_loss_and_logging( + router_loss, + gate_logits, + gate_prob, + gate_logits_mm, + gate_prob_mm, + local_combine_weights, + expert_num_global_no_token_drop, + token_type_ids, + dispatch_token_type_ids, + ) + else: + recv_rank_task and recv_rank_task.wait() # wait for recv_rank + + world_size = dist.get_world_size(self.config.moe_group) + this_rank = dist.get_rank(self.config.moe_group) + + recv_size = paddle.count_nonzero(recv_rank == dist.get_rank(self.config.moe_group)) + recv_size = paddle.maximum(recv_size, paddle.ones([], dtype=recv_size.dtype)) + + recv_size_cpu, recv_size_task = async_offload(recv_size, get_async_loader()) + + send_rank_this_rank = paddle.count_nonzero(send_rank == this_rank) + + send_rank_this_rank_cpu, send_rank_this_rank_task = async_offload(send_rank_this_rank, get_async_loader()) + + recv_rank[recv_rank == -1] = world_size + send_recv_count_global = paddle.scatter_nd_add( + paddle.zeros( + [self.num_local_experts, world_size + 1, world_size + 1], + dtype="int32", + ), + paddle.stack([local_expert_id, send_rank, recv_rank], -1), + paddle.ones([len(send_rank)], dtype="int32"), + ) # [num_local_experts, world_size + 1 , world_size + 1] + send_counts_cpu = send_recv_count_global[:, this_rank, :-1].numpy() + recv_counts_cpu = send_recv_count_global[:, :-1, this_rank].numpy() + send_counts_num_cpu = send_counts_cpu.sum(-1) + recv_counts_num_cpu = recv_counts_cpu.sum(-1) + + dispatched_input = self.forward_experts(*dispatched_input) + + if recv_size_task is not None: + recv_size_task.cpu_wait() + if send_rank_this_rank_task is not None: + send_rank_this_rank_task.cpu_wait() + + input_size = sum([len(i) if i is not None else 0 for i in dispatched_input]) + if self.use_padding or input_size > 1: + assert send_rank_this_rank_cpu.item() == input_size, ( + send_rank, + [len(i) if i is not None else 0 for i in dispatched_input], + ) + + expert_out_to_combine, router_loss2, distributed_input_to_alltoall_out = AlltoAllSmart.apply( + *dispatched_input, + router_loss, + gate_logits, + gate_prob, + gate_logits_mm, + gate_prob_mm, + local_combine_weights, + expert_num_global_no_token_drop, + token_type_ids, + dispatch_token_type_ids, + forward_func_dict=None, + router_loss_fn=self.calc_router_loss_and_logging, + local_expert_id=local_expert_id, + send_rank_global=send_rank, + recv_rank_global=recv_rank, + num_local_experts=self.num_local_experts, + capacity=dispatched_input[0].shape[1] if self.use_padding else None, + use_padding=self.use_padding, + expert_num_global=expert_num_global_list, + is_first_fwd=not framework._dygraph_tracer()._has_grad, + group=self.config.moe_group, + recv_size=recv_size_cpu, + send_counts=send_counts_cpu, + recv_counts=recv_counts_cpu, + send_counts_num=send_counts_num_cpu, + recv_counts_num=recv_counts_num_cpu, + ) + # /origin input -> distributed input/ => /origin-input -> alltoall out -input/ + local_scatter_index = distributed_input_to_alltoall_out[local_scatter_index] + local_scatter_index.stop_gradient 
= True + # global -> local + combined_output = self.combine_expert_output(expert_out_to_combine, local_combine_weights, local_scatter_index) + + if self.shared_experts is not None: + shared_out = self.shared_experts(input) + combined_output += shared_out + + if orig_shape: + combined_output = combined_output.reshape(orig_shape[:-1] + [combined_output.shape[-1]]) + + return combined_output, local_combine_weights, router_loss2, gate_logits + + def fused_gate_logits_process_fused(self, gate_logits_lm, gate_logits_mm=None, token_type_ids=None): + """Process gating logits for expert selection in Mixture-of-Experts (MoE) layers. + + Core Functionality: + - Transforms raw gating logits into expert selection weights and IDs + - Supports both grouped and standard expert selection modes + - Handles bias correction for improved expert load balancing + + Args: + gate_logits_lm (Tensor): Raw gating scores of shape [batch_size, total_experts] + + Returns: + tuple: ( + lm_weight_and_expert_id: Combined tensor containing selection weights + and expert IDs [batch_size, 2*top_k], + prob_flat: Flattened expert probabilities [batch_size, total_experts] + ) + """ + top_k = self.k + num_expert_per_rank_per_modality = gate_logits_lm.shape[-1] // self.config.moe_world_size + group_size = gate_logits_lm.shape[-1] // top_k + if self.group_experts: + assert not self.use_correction_bias + gate_logits_lm = gate_logits_lm.reshape([gate_logits_lm.shape[0], top_k, -1]) + prob_lm = self.gate.act(gate_logits_lm) + prob_lm_ = prob_lm + weight_lm, expert_id_lm = prob_lm_.topk(k=1, axis=-1) + weight_lm = weight_lm.reshape([gate_logits_lm.shape[0], -1]) + group_size = gate_logits_lm.shape[-1] + expert_id_lm = expert_id_lm.squeeze(-1) + else: + prob_lm = self.gate.act(gate_logits_lm) + if self.use_correction_bias: + prob_lm_ = prob_lm + self.moe_statics.e_score_correction_bias[0].detach() + else: + prob_lm_ = prob_lm + weight_lm, expert_id_lm = prob_lm_.topk(k=top_k, axis=-1) + + if self.use_correction_bias: + batch_idx = paddle.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) + weight_lm = prob_lm[batch_idx, expert_id_lm] # use correct bias + + expert_id_lm = expand_modality_expert_id( + expert_id_lm, + num_expert_per_modality=(num_expert_per_rank_per_modality if token_type_ids is not None else 0), + group_size=group_size, + modality_offset=0, + is_group_expert=self.group_experts, + ) + expert_id_lm = expert_id_lm.reshape(weight_lm.shape) + lm_weight_and_expert_id = paddle.concat([weight_lm, expert_id_lm.astype("float32")], -1) + + if token_type_ids is None or gate_logits_mm is None: + return ( + lm_weight_and_expert_id, + prob_lm.reshape([prob_lm.shape[0], -1]), + None, + ) + + prob_mm = self.gate.act(gate_logits_mm) + if self.use_correction_bias: + prob_mm_ = prob_mm + self.moe_statics.e_score_correction_bias[1].detach() + else: + prob_mm_ = prob_mm + weight_mm, expert_id_mm = prob_mm_.topk(k=top_k, axis=-1) + if self.use_correction_bias: + batch_idx = paddle.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) + weight_mm = prob_mm[batch_idx, expert_id_mm] # use correct bias + + expert_id_mm = expand_modality_expert_id( + expert_id_mm, + num_expert_per_modality=num_expert_per_rank_per_modality, + group_size=group_size, + modality_offset=1, + is_group_expert=False, + ) + expert_id_mm = expert_id_mm.reshape(weight_mm.shape) + mm_weight_and_expert_id = paddle.concat([weight_mm, expert_id_mm.astype("float32")], -1) + weight_and_expert = paddle.where( + (token_type_ids == 0).unsqueeze(-1), + 
lm_weight_and_expert_id, + mm_weight_and_expert_id, + ) + return weight_and_expert, prob_lm.reshape([prob_lm.shape[0], -1]), prob_mm + + def fused_gate_and_dispatch(self, input, token_type_ids=None, global_dense_expert_mask=None): + """Implements fused expert gating and token dispatch logic for Mixture-of-Experts (MoE) layers. + + Core Functionality: + - Computes expert selection probabilities and routing weights + - Performs distributed token-to-expert assignment + - Handles communication and synchronization in model-parallel environments + + Args: + input (Tensor): Input tensor of shape [seq_len, hidden_dim] + + Returns: + tuple: ( + dispatched_input: Expert-assigned tokens [num_experts, capacity, hidden_dim], + global_hidden_states: Full sequence representations, + local_combine_weights: Local expert combination weights, + expert_num_global_notrunc: Global expert token counts (without capacity truncation), + expert_num_global: Actual expert token counts, + expert_num_global_list: Per-expert token counts, + local_scatter_index: Local token reorganization indices, + scatter_index_rev: Reverse scattering indices, + router_loss: Calculated routing loss, + gate_outputs: Raw gating network outputs, + expert_num_local: Local expert utilization counts + ) + """ + seqlen, d_model = input.shape + args = () + if token_type_ids is not None: + token_type_ids = token_type_ids.reshape([-1]) + args = (token_type_ids,) + + router_loss = paddle.zeros([1], dtype="float32") + router_loss.stop_gradient = False + top_k = self.k + + def build_weights_and_expert_id(input): + nonlocal token_type_ids, args + logits, capacity, router_loss = self.gate(input, *args, transform_weight=False) + if self.config.multimodel_experts: + gate_logits_lm, gate_logits_mm = logits.chunk(2, axis=-1) + else: + gate_logits_lm, gate_logits_mm = logits, None + + weigth_and_expert, gate_prob_lm, gate_prob_mm = self.fused_gate_logits_process_fused( + gate_logits_lm, + gate_logits_mm, + token_type_ids if global_dense_expert_mask is None else None, + ) + weigth_and_expert = AllGatherGroupOp.apply(weigth_and_expert, group=self.config.moe_group) + return ( + weigth_and_expert, + gate_logits_lm, + gate_logits_mm, + gate_prob_lm, + gate_prob_mm, + ) + + capacity = self.gate.get_capacity(input.shape[0]) * self.world_size + ( + global_hidden_states, + combine_weights_and_expert_id, + gate_logits_lm, + gate_logits_mm, + gate_prob_lm, + gate_prob_mm, + ) = AllGatherAsync.apply( + input, + input, + fn=build_weights_and_expert_id, + group=self.config.moe_group, + is_first_fwd=not framework._dygraph_tracer()._has_grad, + ) + combine_weights_unnorm, expert_id = combine_weights_and_expert_id.chunk(2, axis=-1) + expert_id = expert_id.cast("int32") + expert_id.stop_gradient = True + num_experts = ( + sum(self.config.moe_num_experts) + if isinstance(self.config.moe_num_experts, (tuple, list)) + else self.config.moe_num_experts + ) # all-experts = 96 + if global_dense_expert_mask is not None: + combine_weights_unnorm[global_dense_expert_mask] = 0.0 + expert_id[global_dense_expert_mask] = num_experts + num_experts += 1 + + if "reverse_token_drop" in inspect.signature(moe_gate_dispatch_partial_nosoftmaxtopk).parameters: + compat_kwargs = {"reverse_token_drop": self.enable_reverse_token_drop} + else: + compat_kwargs = {} + + # Disable AMP because: + # - combine_weights_unnorm is fp32, global_hidden_states is bf16 + # - AMP O2 would upcast global_hidden_states to fp32, making dispatched_input fp32 + # - This is a data movement op with no computation, so 
upcasting is unnecessary + with paddle.amp.auto_cast(False): + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, # input -> dispatched_input + scatter_index_rev, # dispatch-input -> input + expert_num_global, + expert_num_local, + ) = moe_gate_dispatch_partial_nosoftmaxtopk( + global_hidden_states, + combine_weights_unnorm, + expert_id, + top_k, + capacity, + num_experts, + self.use_padding, + expert_start_index=self.num_local_experts * self.config.moe_rank, + expert_end_index=self.num_local_experts * (self.config.moe_rank + 1), + **compat_kwargs, + ) + + if self.use_correction_bias: + if self.gate.config.multimodel_experts: + # MLLM + for i in range(len(self.moe_statics.expert_usage)): + self.moe_statics.expert_usage[i] += expert_num_local[self.gate.experts_type_mask[i]].detach() + else: + # LLM + self.moe_statics.expert_usage[0] += expert_num_local.detach() + + # When use unpad , `moe_ops_partial` output likes `scatter_index_rev==[]`. + if scatter_index_rev.ndim == 0: + assert not self.use_padding + scatter_index_rev = paddle.empty([0], dtype=scatter_index_rev.dtype) + + dispatched_input.stop_gradient = False + combine_weights_unnorm.stop_gradient = False + scatter_index.stop_gradient = True + expert_num_global.stop_gradient = True + expert_num_global_notrunc = expert_num_global + self.capacity_tensor = paddle.to_tensor(capacity, dtype=expert_num_global.dtype) + expert_num_global = paddle.minimum(expert_num_global, self.capacity_tensor) + + if global_dense_expert_mask is not None: + expert_num_global = expert_num_global[:-1] + expert_num_local = expert_num_local[:-1] + expert_num_global_notrunc = expert_num_global_notrunc[:-1] + + scatter_index = scatter_index.transpose([1, 0]) # [k,s] ->[s,k] + + last_local_expert = self.num_local_experts * self.config.moe_rank + expert_offset_global = expert_num_global.cumsum() + + loader = get_async_loader() + expert_num_global_list, offload_task = async_offload(expert_num_global, loader) + if self.use_padding: + offset = last_local_expert * capacity + else: + offset = expert_offset_global[last_local_expert - 1] if self.config.moe_rank > 0 else 0 + local_combine_weights_unnorm = ReshardCombineWeight.apply( + combine_weights_unnorm.contiguous(), group=self.config.moe_group + ) + local_scatter_index = ReduceScatterGroupOp.apply( + paddle.where( + combine_weights_unnorm > 0.0, + scatter_index + offset, + scatter_index, + ), + group=self.config.moe_group, + ) + if self.gate.norm_gate_logits: + local_combine_weights = local_combine_weights_unnorm / paddle.clip( + local_combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 + ) + else: + local_combine_weights = local_combine_weights_unnorm + local_combine_weights = local_combine_weights.cast(dispatched_input.dtype) + if self.use_padding: + dispatched_input = dispatched_input.reshape([self.num_local_experts, -1, d_model]) + dispatched_input = dispatched_input.unbind(0) + else: + s = self.num_local_experts * self.config.moe_rank + e = self.num_local_experts * (self.config.moe_rank + 1) + expert_num_local = expert_num_local.tolist()[s:e] + expert_num_local_valid = [i for i in expert_num_local if i > 0] + valid_pos = [j for j, i in enumerate(expert_num_local) if i > 0] + if expert_num_local_valid: + dispatched_input_list = dispatched_input.split(expert_num_local_valid) + dispatched_input = [None] * len(expert_num_local) + for p, t in zip(valid_pos, dispatched_input_list): + dispatched_input[p] = t + else: + dispatched_input = [dispatched_input] + ([None] * (len(expert_num_local) - 1)) + + 
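+        # At this point ``dispatched_input`` holds one entry per local expert;
+        # in the un-padded path, local experts that received no tokens keep a
+        # ``None`` entry, and ``forward_experts`` later feeds them a dummy token.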
scatter_index.stop_gradient = True + scatter_index_rev.stop_gradient = True + if offload_task is not None: + hack_offload_wait(offload_task) + expert_num_global_list = expert_num_global_list.tolist() + + return ( + dispatched_input, + global_hidden_states, + local_combine_weights, + expert_num_global_notrunc, # for auxloss calculation. + expert_num_global, + expert_num_global_list, + local_scatter_index, + scatter_index_rev, + router_loss, + (gate_logits_lm, gate_prob_lm), + (gate_logits_mm, gate_prob_mm), + expert_num_local, + ) + + def forward_experts(self, *dispatched_input): + """Execute expert model computations in sequence for Mixture-of-Experts (MoE) layer. + + Core Functionality: + - Distributes dispatched tokens to local expert models + - Handles empty expert inputs with zero-initialized fallback + - Maintains gradient flow for expert outputs + - Aggregates outputs from all active experts + + Args: + *dispatched_input: Variable-length expert-specific input tensors + + Returns: + list: Expert output tensors (None for inactive experts) + + Implementation Details: + 1. Processes valid expert inputs through corresponding expert models + 2. Generates dummy inputs for inactive experts to preserve model structure + 3. Aggregates dummy outputs to first active expert to maintain gradient flow + """ + expert_outputs = [] + assert isinstance(self.experts, nn.LayerList), type(self.experts) + + no_tokens_expert_outputs = [] + if not self.multimodal_experts: + true_experts = self.experts[self.rank * self.num_local_experts : (self.rank + 1) * self.num_local_experts] + else: + true_experts = [] + for i, num in enumerate(self.num_local_multimodal_experts): + current_modal_experts = self.experts[ + self.multimodal_expert_index[i] : self.multimodal_expert_index[i + 1] + ] + true_experts.extend(current_modal_experts[self.rank * num : (self.rank + 1) * num]) + + assert len(dispatched_input) == len(true_experts), ( + len(dispatched_input), + len(true_experts), + ) + + for iexpert, chunk in enumerate(dispatched_input): + if chunk is None: + # QuantizationLoRALinear can not call `.weight`. 
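+                # For an expert that received no tokens, a 1-token zero input is
+                # built from the expert's input width (read from lora_A when
+                # up_gate_proj is QuantizationLoRALinear, from .weight otherwise);
+                # the resulting output is scaled by 0.0 and later added to an
+                # active expert's output so gradients still reach every expert.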
+ if not isinstance(true_experts[iexpert].up_gate_proj, QuantizationLoRALinear): + input_shape = [ + 1, + true_experts[iexpert].up_gate_proj.weight.shape[0], + ] + input_dtype = true_experts[iexpert].up_gate_proj.weight.dtype + else: + input_shape = [ + 1, + true_experts[iexpert].up_gate_proj.lora_A.shape[0], + ] + input_dtype = true_experts[iexpert].up_gate_proj.lora_A.dtype + + chunk = paddle.zeros( + input_shape, + input_dtype, + ) + if true_experts[iexpert].training: + chunk.stop_gradient = False + expert_out = true_experts[iexpert](chunk.contiguous()) + no_tokens_expert_outputs.append(expert_out * 0.0) # mutiply 0.0 to zero out and grad + + expert_outputs.append(None) + continue + + expert_out = true_experts[iexpert](chunk.contiguous()) + expert_outputs.append(expert_out) + + # if self.config.moe_layer_feed_fake_token and len(no_tokens_expert_outputs) > 0: + if len(no_tokens_expert_outputs) > 0: + first_has_tokens_idx = 0 + for idx, expert_out in enumerate(expert_outputs): + if expert_out is not None: + first_has_tokens_idx = idx + break + for idx, expert_out in enumerate(no_tokens_expert_outputs): + expert_outputs[first_has_tokens_idx] += expert_out + + return expert_outputs + + def calc_router_loss_and_logging( + self, + router_loss, + gate_logits, + gate_prob, + gate_logits_mm, + gate_prob_mm, + combine_weights, + dispatch_mask, + token_type_ids, + dispatch_token_type_ids, + ): + """Calculate and aggregate router auxiliary loss for Mixture-of-Experts training. + + Core Functionality: + - Computes expert load balancing loss to prevent expert under-utilization + - Integrates multiple loss components from different routing stages + - Maintains gradient flow for routing mechanism optimization + + Args: + router_loss (Tensor): Accumulated router loss tensor + gate_logits (Tensor): Raw gating network outputs [batch_size, num_experts] + gate_prob (Tensor): Activated gating probabilities [batch_size, num_experts] + combine_weights (Tensor): Expert combination weights [batch_size, top_k] + dispatch_mask (Tensor): Token dispatch mask indicating expert assignments + + Returns: + Tensor: Updated router loss with new auxiliary components + """ + dispatch_mask_3d = dispatch_mask.reshape([self.config.moe_world_size, -1]) + if token_type_ids is not None and self.gate.config.moe_use_hard_gate: + # MLLM + if not self.gate.weight.stop_gradient: + dispatch_tokens_mask = dispatch_token_type_ids == 0 if dispatch_token_type_ids is not None else None + lm_tokens_mask = (token_type_ids == 0).astype(gate_prob.dtype) + # hard code + lm_experts = ( + self.gate.num_experts[0] + if isinstance(self.gate.num_experts, (tuple, list)) + else self.gate.num_experts + ) + dispatch_mask_lm = dispatch_mask_3d[:, : lm_experts // self.config.moe_world_size].reshape([-1]) + router_loss += self._calc_router_loss( + dispatch_mask_lm, + gate_logits * lm_tokens_mask.unsqueeze(-1), + gate_prob * lm_tokens_mask.unsqueeze(-1), + self.gate.num_experts_list[0], + self.group_experts, + self.layer_idx, + 0, # ortholoss + lm_tokens_mask, + dispatch_tokens_mask, + prefix="lm", + ) + else: + router_loss += self.zero * gate_logits[0, 0] * gate_prob[0, 0] + if gate_prob_mm is not None: + mm_tokens_mask = (token_type_ids == 1).astype(gate_prob_mm.dtype) + dispatch_tokens_mask = dispatch_token_type_ids == 1 if dispatch_token_type_ids is not None else None + dispatch_mask_mm = dispatch_mask_3d[ + :, self.gate.num_experts[0] // self.config.moe_world_size : + ].reshape([-1]) + + router_loss += self._calc_router_loss( + dispatch_mask_mm, + 
gate_logits_mm * mm_tokens_mask.unsqueeze(-1), + gate_prob_mm * mm_tokens_mask.unsqueeze(-1), + self.gate.num_experts_list[1], + False, + self.layer_idx, + 1, + mm_tokens_mask, + dispatch_tokens_mask, + prefix="mm", + ) + + else: + # LLM + router_loss += self._calc_router_loss( + dispatch_mask, + gate_logits, + gate_prob, + self.gate.num_experts_tensor, + self.group_experts, + self.layer_idx, + 0, + paddle.ones([gate_prob.shape[0]], "bool"), + paddle.ones([self.gate.config.moe_world_size * gate_prob.shape[0]], "bool"), + prefix="lm", + ) + + return router_loss diff --git a/paddleformers/nn/moe/moe_alltoall_layer.py b/paddleformers/nn/moe/moe_alltoall_layer.py new file mode 100644 index 00000000000..876fcabbfe0 --- /dev/null +++ b/paddleformers/nn/moe/moe_alltoall_layer.py @@ -0,0 +1,693 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""_summary_ + +Returns: + _type_: _description_ +""" + +import inspect +import itertools +from collections import namedtuple +from typing import List, Optional, Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import Tensor, _C_ops, framework, nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.communication import stream +from paddle.distributed.communication.group import Group +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.nn.functional import moe_combine, moe_gate_dispatch + +from paddleformers.utils.log import logger + +from .abstract import MOELayerBase +from .utils import ScatterOp + + +class GateCombine(PyLayer): + """ + Custom PyLayer for gate combination operations with backward pass. + """ + + @staticmethod + def forward(ctx, x, combine_weights, scatter_index): + """ + Forward pass for gate combination. + + Args: + x: Input tensor + combine_weights: Combination weights + scatter_index: Scatter indices + + Returns: + Tensor: Combined output + """ + ctx.x = x + ctx.combine_weights = combine_weights + ctx.scatter_index = scatter_index + ret = moe_combine(x, combine_weights, scatter_index) + return ret + + @staticmethod + def backward(ctx, grad_y, *_): + """ + Backward pass for gate combination. + + Args: + grad_y: Gradient of output [seqlen, hidden_size] + + Returns: + tuple: (grad_x, grad_combine_weight, None) + """ + grad_x, grad_combine_weight_helper = _C_ops.moe_combine_grad( + ctx.x, ctx.combine_weights, ctx.scatter_index, grad_y + ) + # grad_combine_weight_helper is the same shape with grad x [seqlen * K, dim] + # reduce the hidden shape + # TODO: implement reduce in cuda ops + grad_combine_weight = grad_combine_weight_helper.sum(-1) + return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None + + +def combining(x, combine_weights, scatter_index, hard_gate=False): + """ + Fused version of combining operation. 
+ + Args: + x: Input tensor [seq, dim] + combine_weights: Combination weights [s, k] + scatter_index: Scatter indices [k, s] + hard_gate: Whether to use hard gating + + Returns: + Tensor: Combined output [s, dim] + """ + if hard_gate: + x_gatherd = F.embedding(scatter_index, x) # [s,k,dim] + return x_gatherd.squeeze(-2) + if paddle.device.is_compiled_with_custom_device("npu"): + from ernie.fusion_ops.npu_fusion_ops import npu_combining + + ret = npu_combining(x, combine_weights, scatter_index) + else: + ret = GateCombine.apply(x, combine_weights, scatter_index) + ret.stop_gradient = False + return ret + + +class MOEAlltoAllLayer(MOELayerBase): + """ + Mixture of Experts layer implementation based on GShard paper. + """ + + def __init__( + self, + gate: nn.Layer, + experts: List[nn.Layer], + layer_idx, + shared_experts: Optional[List[nn.Layer]] = None, + group: Group = None, + recompute=False, + k=2, + all_to_all_dropout=0, + group_experts=False, + moe_statics=None, + moe_num_experts=None, + ): + """ + Initialize MoE layer. + + Args: + gate: Gate network for expert selection + experts: List of expert networks + layer_idx: Index of this layer in the model + group: Distributed communication group + recompute: Whether to enable recomputation + k: Number of experts to select per token + all_to_all_dropout: Dropout rate for all-to-all communication + group_experts: Whether to group experts + moe_statics: MoE statistics tracking object + """ + super().__init__() + self.gate = gate + self.layer_idx = layer_idx + self.recompute = recompute + for p in self.gate.parameters(): + p.is_gate = True + if isinstance(experts, nn.LayerList): + self.experts = experts + else: + logger.info(f"using fused experts, type={type(experts)}") + self.experts = experts + self.shared_experts = shared_experts + + self.group = group + self.k = k + self.all_to_all_dropout = all_to_all_dropout + self.use_correction_bias = moe_statics is not None + self.moe_statics = moe_statics + if self.use_correction_bias: + logger.info(f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}") + assert self.gate.config.moe_use_aux_free + + self.is_mp_moe = ( + hasattr(fleet.fleet, "_hcg") and group is fleet.get_hybrid_communicate_group().get_model_parallel_group() + ) + is_dummy_moe = dist.get_world_size(group) == 1 + + for p in experts.parameters(): + p.expert = not (self.is_mp_moe or is_dummy_moe) # type: ignore + p.no_sync = not (self.is_mp_moe or is_dummy_moe) + if self.is_mp_moe: + p.is_distributed = True + p.mp_moe = True + + self.world_size = dist.get_world_size(self.group) + # assert self.world_size > 1, f'moe-group not found, world_size {self.world_size}' + self.rank = dist.get_rank(self.group) + if self.world_size < 1: + self.world_size = 1 + if self.rank < 0: + self.rank = 0 + + self.multimodal_experts = isinstance(moe_num_experts, (tuple, list)) and len(moe_num_experts) > 1 + self.num_local_experts = len(self.experts) // self.world_size + if self.multimodal_experts: + self.num_local_multimodal_experts = [num // self.world_size for num in moe_num_experts] + self.multimodal_expert_index = [0] + list(itertools.accumulate(moe_num_experts)) + + self.input_preprocess = self.output_postprocess = None + self.group_experts = group_experts + self.config = self.gate.config + self.zero = paddle.to_tensor(0, dtype=paddle.float32) + + def forward_experts(self, dispatched_input): + """ + Forward pass through experts sequentially. 
+ + Args: + dispatched_input: Input tensor of shape [num_experts, capacity, dim] + + Returns: + Tensor: Expert outputs of shape [num_experts, capacity, dim] + """ + + if not self.multimodal_experts: + true_experts = self.experts[self.rank * self.num_local_experts : (self.rank + 1) * self.num_local_experts] + else: + true_experts = [] + for i, num in enumerate(self.num_local_multimodal_experts): + current_modal_experts = self.experts[ + self.multimodal_expert_index[i] : self.multimodal_expert_index[i + 1] + ] + true_experts.extend(current_modal_experts[self.rank * num : (self.rank + 1) * num]) + + dispatched_input = dispatched_input.reshape( + [self.world_size, self.num_local_experts, -1, dispatched_input.shape[-1]] + ) # [e,1,c,m] + expert_outputs = [] + if isinstance(self.experts, nn.LayerList): + chunks = dispatched_input.transpose([1, 0, 2, 3]).contiguous().unbind(0) + assert len(chunks) == len(true_experts), (len(chunks), len(true_experts)) + for chunk, expert in zip(chunks, true_experts): + expert_outputs += [expert(chunk)] + else: + dispatched_input = dispatched_input.transpose([1, 0, 2, 3]) + dispatched_input.contiguous() + orig_shape = dispatched_input.shape + chunks = dispatched_input.reshape([orig_shape[0], -1, orig_shape[-1]]) + chunks = self.experts(chunks) + chunks = chunks.reshape(orig_shape[:-1] + [chunks.shape[-1]]).unbind(0) + expert_outputs += chunks + expert_output = paddle.stack(expert_outputs, axis=1) # [ecm] + return expert_output + + def fused_gate_logits_process(self, gate_logits, token_type_ids=None, offload_helper=None): + """ + Process and combine gate logits. + + Args: + gate_logits: Raw gate logits + + Returns: + tuple: (processed probabilities, max probabilities) + """ + k = self.k + experts_type_ids = self.gate.experts_type_ids + use_hard_gate = self.config.moe_use_hard_gate + max_prob = None + + if token_type_ids is not None and use_hard_gate: + if offload_helper is None: + offload_helper = dict() + lm_mask = token_type_ids == 0 + is_lm = lm_mask.any() + mm_mask = token_type_ids == 1 + is_mm = mm_mask.any() + seq_lm = lm_mask.sum() + seq_mm = mm_mask.sum() + lm_mask = lm_mask.unsqueeze(1) & (experts_type_ids == 0).unsqueeze(0) + mm_mask = mm_mask.unsqueeze(1) & (experts_type_ids == 1).unsqueeze(0) + offload_helper["lm_mask"] = [lm_mask, is_lm, seq_lm] + offload_helper["mm_mask"] = [mm_mask, is_mm, seq_mm] + + is_lm = offload_helper["lm_mask"][1] + prob = paddle.zeros_like(gate_logits) + # 处理 lm_prob + if is_lm: + lm_mask = offload_helper["lm_mask"][0] + seq_lm_cpu = offload_helper["lm_mask"][2] + lm_mask_nonzero = lm_mask.nonzero() + lm_partial_gate_logits = gate_logits.gather_nd(lm_mask_nonzero).reshape([seq_lm_cpu, -1]) + if self.group_experts: + lm_prob = self.gate.act(lm_partial_gate_logits.reshape([lm_partial_gate_logits.shape[0], k, -1])) + max_prob = lm_prob.max(-1, keepdim=True) # [s_l, k, 1] + lm_prob /= max_prob + else: + lm_prob = self.gate.act(lm_partial_gate_logits) + prob = paddle.scatter_nd_add(prob, lm_mask_nonzero, lm_prob.flatten()) + # 处理 mm_prob + is_mm = offload_helper["mm_mask"][1] + if is_mm: + mm_mask = offload_helper["mm_mask"][0] + seq_mm_cpu = offload_helper["mm_mask"][2] + mm_mask_nonzero = paddle.nonzero(mm_mask) + mm_partial_gate_logits = gate_logits.gather_nd(mm_mask_nonzero).reshape([seq_mm_cpu, -1]) + mm_prob = self.gate.act(mm_partial_gate_logits) + prob = paddle.scatter_nd_add(prob, mm_mask_nonzero, mm_prob.flatten()) + else: + # 处理非硬门和不需要token_type_ids的情况 + if self.group_experts: + prob = 
self.gate.act(gate_logits.reshape([gate_logits.shape[0], k, -1])) + max_prob = prob.max(-1, keepdim=True) + prob /= max_prob + prob = prob.reshape([prob.shape[0], -1]) + else: + prob = self.gate.act(gate_logits) + return prob, max_prob + + def gate_and_dispatch(self, input, token_type_ids=None): + """ + Calculate gate and dispatch inputs. + + Args: + input: Input tensor of shape [seq, dim] + + Returns: + tuple: (dispatched_input, combine_weights, dispatch_mask, + scatter_index, router_loss, gate_logits, gate_prob) + """ + seqlen, d_model = input.shape + args = () + if token_type_ids is not None: + token_type_ids = token_type_ids.reshape([-1]) + args = (token_type_ids,) + + ( + gate_logits, + capacity, + router_loss, + ) = self.gate(input, *args) + if self.input_preprocess is not None: + input, gate_logits = self.input_preprocess(input, gate_logits, capacity) + # capacity no use + k = self.k + prob, max_prob = self.fused_gate_logits_process(gate_logits, token_type_ids) + + if "corr_bias" in inspect.signature(moe_gate_dispatch).parameters: + if self.use_correction_bias: + compat_args = (self.moe_statics.e_score_correction_bias[0],) + else: + compat_args = (None,) + else: + assert not self.use_correction_bias, "correction bias not supported, rebuild moe-ops" + compat_args = () + + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, + dispatch_mask, + _, + ) = moe_gate_dispatch(input, prob, *compat_args, k=k, capacity=capacity, use_pad=True) + dispatched_input = dispatched_input.astype(input.dtype) + + dispatch_mask = paddle.diff(F.pad(dispatch_mask, (1, 0))) + if self.use_correction_bias: + if self.gate.config.multimodel_experts: + for i in range(len(self.moe_statics.expert_usage)): + self.moe_statics.expert_usage[i] += dispatch_mask[self.gate.experts_type_mask[i]].detach() + else: + self.moe_statics.expert_usage[0] += dispatch_mask.detach() + dispatched_input.stop_gradient = False + combine_weights_unnorm.stop_gradient = False + scatter_index.stop_gradient = True + dispatch_mask.stop_gradient = True + + scatter_index = scatter_index.transpose([1, 0]) # [k,s] ->[s,k] + if self.group_experts: + if max_prob is not None: + if token_type_ids is not None: + p = paddle.ones_like(combine_weights_unnorm.unsqueeze(-1)) + p = paddle.scatter_nd_add(p, paddle.nonzero(token_type_ids == 0), -1 + max_prob) + else: + p = max_prob + combine_weights_unnorm = (combine_weights_unnorm.unsqueeze(-1) * p).squeeze(-1) + # gate_prob 进行还原 + prob = (prob.reshape([p.shape[0], k, -1]) * p).reshape([p.shape[0], -1]) + if self.gate.norm_gate_logits: + combine_weights = combine_weights_unnorm / paddle.clip( + combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 + ) + else: + combine_weights = combine_weights_unnorm + combine_weights = combine_weights.cast(dispatched_input.dtype) + + dispatched_input = dispatched_input.reshape([self.world_size * self.num_local_experts, capacity, d_model]) + dispatch_mask.stop_gradient = True + scatter_index.stop_gradient = True + return ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + prob, + ) + + def _calc_router_loss( + self, + dispatch_mask, + gate_logits, + gate_prob, + num_experts, + use_group, + layer_idx, + token_type=None, + tokens_type_mask=None, + dispatch_tokens_mask=None, + prefix="", + ): + """ + Calculate router loss including auxiliary loss, z-loss and orthogonal loss. 
+ + Args: + dispatch_mask: Dispatch mask + gate_logits: Gate logits + gate_prob: Gate probabilities + num_experts: Number of experts + use_group: Whether to use expert groups + layer_idx: Layer index + token_type: Token type + tokens_type_mask: Token type mask + dispatch_tokens_mask: Dispatch tokens mask + prefix: Prefix for logging + + Returns: + Tensor: Total router loss + """ + router_loss, l_aux, orthogonal_loss, zloss = 0.0, None, None, None + if self.gate.config.moe_aux_loss_lambda: + l_aux = self.gate._cal_aux_loss( + gate_prob, + dispatch_mask, + num_experts, + use_group, + tokens_type_mask, + dispatch_tokens_mask, + ) + router_loss += self.gate.moe_aux_loss_lambda[token_type or 0] * l_aux + else: + router_loss += self.zero * gate_prob[0, 0] # must use gate prob to avoid zero pointer + if self.gate.config.moe_orthogonal_loss_lambda: + orthogonal_loss = self.gate._cal_orthogonal_loss(token_type, use_group) + router_loss += self.gate.moe_orthogonal_loss_lambda[token_type or 0] * orthogonal_loss + if self.gate.config.moe_z_loss_lambda: + zloss = self.gate._cal_z_loss(gate_logits, tokens_type_mask) + router_loss += self.gate.moe_z_loss_lambda[token_type or 0] * zloss + return router_loss + + def calc_router_loss_and_logging( + self, + router_loss, + combine_weights, + dispatch_mask, + gate_logits, + gate_prob, + token_type_ids=None, + dispatch_token_type_ids=None, + offload_helper=None, + ): + """ + Calculate auxiliary losses and log statistics in fused expert case. + + Args: + router_loss: Base router loss + combine_weights: Combination weights + dispatch_mask: Dispatch mask + gate_logits: Gate logits + gate_prob: Gate probabilities + + Returns: + Tensor: Updated router loss + """ + assert gate_prob is not None + if token_type_ids is not None and self.gate.config.moe_use_hard_gate: # true + if not self.gate.weight.stop_gradient: + lm_tokens_mask = token_type_ids == 0 + if offload_helper is not None: + is_lm = offload_helper["lm_mask"][1] + else: + is_lm = lm_tokens_mask.any() + if is_lm: + dispatch_tokens_mask = ( + dispatch_token_type_ids == 0 if dispatch_token_type_ids is not None else None + ) + router_loss += self._calc_router_loss( + ( + dispatch_mask[self.gate.experts_type_mask[0]] + if hasattr(self.gate, "experts_type_mask") + else dispatch_mask + ), + ( + gate_logits[:, self.gate.experts_type_mask[0]] + if hasattr(self.gate, "experts_type_mask") + else gate_logits + ), + ( + gate_prob[:, self.gate.experts_type_mask[0]] + if hasattr(self.gate, "experts_type_mask") + else gate_prob + ), + ( + self.gate.num_experts_list[0] + if hasattr(self.gate, "num_experts_list") + else self.gate.num_experts_tensor + ), + self.group_experts, + self.layer_idx, + 0, + lm_tokens_mask, + dispatch_tokens_mask, + prefix="lm", + ) + mm_tokens_mask = token_type_ids == 1 + if offload_helper is not None: + is_mm = offload_helper["mm_mask"][1] + else: + is_mm = mm_tokens_mask.any() + if is_mm: + dispatch_tokens_mask = dispatch_token_type_ids == 1 if dispatch_token_type_ids is not None else None + router_loss += self._calc_router_loss( + dispatch_mask[self.gate.experts_type_mask[1]], + gate_logits[:, self.gate.experts_type_mask[1]], + gate_prob[:, self.gate.experts_type_mask[1]], + self.gate.num_experts_list[1], + False, + self.layer_idx, + 1, + mm_tokens_mask, + dispatch_tokens_mask, + prefix="mm", + ) + + else: + router_loss += self._calc_router_loss( + dispatch_mask, + gate_logits, + gate_prob, + self.gate.num_experts_tensor, + self.group_experts, + self.layer_idx, + ) + + return router_loss + + def 
combine_expert_output(self, expert_output, combine_weights, scatter_index): + """ + Combine expert outputs using combination weights. + + Args: + expert_output: Expert outputs [num_experts, capacity, dim] + combine_weights: Combination weights + scatter_index: Scatter indices + + Returns: + Tensor: Combined output [seqlen, dim] + """ + expert_output = expert_output.reshape([-1, expert_output.shape[-1]]) # [e*1,c,m] + combined_output = combining(expert_output, combine_weights, scatter_index) + + if self.output_postprocess is not None: + combined_output = self.output_postprocess(combined_output) + + return combined_output + + def forward_single_stage(self, dispatched_input, stage_id): + """ + Forward pass for single expert stage. + + Args: + dispatched_input: Dispatched input + stage_id: Stage index + + Returns: + Tensor: Expert output + """ + assert isinstance(self.experts, nn.LayerList) + return self.experts[stage_id](dispatched_input) + + def all2all_expert_overlap(self, x, group): + """all2all_expert_overlap""" + all2all_tasks = [] + all2all_ins = paddle.unbind(x, axis=0) + for stage_id in range(1): + stage_input = all2all_ins[stage_id] + x_out, task = AlltoAll.apply(stage_input, group=self.group, sync_op=False) + all2all_tasks.append((task, x_out)) + + expert_outputs = [] + for stage_id in range(self.num_local_experts): + if stage_id + 1 != self.num_local_experts: + stage_input = all2all_ins[stage_id + 1] + x_out, task = AlltoAll.apply(stage_input, group=self.group, sync_op=False) + all2all_tasks.append((task, x_out)) + + task, dispatched_input = all2all_tasks[stage_id] + task.wait() + expert_outputs_cur_stage = ( + recompute(self.forward_single_stage, dispatched_input, stage_id) + if self.recompute and self.training + else self.forward_single_stage(dispatched_input, stage_id) + ) + expert_outputs.append(expert_outputs_cur_stage) + + expert_output = paddle.stack(expert_outputs, axis=1) + return expert_output + + def forward( + self, + input: Tensor, + token_type_ids=None, + **kwargs, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """ + Forward pass through MoE layer. 
+ + Args: + input: Input tensor of shape [s, d] + + Returns: + tuple: (output, combine_weights, router_loss, gate_logits) + """ + # assert len(input) == 1, "only single input Tensor supported" + if input.ndim == 3: + orig_shape = input.shape + input = input.reshape([-1, input.shape[-1]]) + else: + orig_shape = None + assert len(input.shape) == 2, f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" + if token_type_ids is not None: + token_type_ids = token_type_ids.clone()[:, :-1] + if self.config.sequence_parallel: + token_type_ids = token_type_ids.reshape([-1]) + token_type_ids = ScatterOp.apply(token_type_ids) + token_type_ids.stop_gradient = True + + assert self.gate is not None + if hasattr(self, "rng") and self.rng.random() < self.all_to_all_dropout: + orig_shape_2 = input.shape + output = self.forward_experts(input) + output += self.gate.weight.sum() * 0.0 # hack for grad + output = output.reshape(orig_shape or orig_shape_2) # [e*1,c,m] + return output, None, 0 + + is_first_fwd = not framework._dygraph_tracer()._has_grad + gate_input = input + + ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + gate_prob, + ) = self.gate_and_dispatch(gate_input, token_type_ids) + + use_async = self.shared_experts is not None + if use_async: + dispatched_input, shared_out = AlltoAllAsync.apply( + dispatched_input, + input, # args to shared-experts + group=self.group, + fn=self.shared_experts, + is_first_fwd=is_first_fwd, + ) + else: + dispatched_input = AlltoAll.apply(dispatched_input, self.group) + + expert_out = ( + recompute(self.forward_experts, dispatched_input) + if self.recompute and self.training + else self.forward_experts(dispatched_input) + ) + + expert_out, router_loss2 = AlltoAllAsync.apply( + expert_out, + router_loss, + combine_weights, + dispatch_mask, + gate_logits, + gate_prob, + token_type_ids, + group=self.group, + fn=self.calc_router_loss_and_logging, + is_first_fwd=is_first_fwd, + ) + + combined_output = self.combine_expert_output(expert_out, combine_weights, scatter_index) + + if self.shared_experts is not None: + combined_output += shared_out + + if orig_shape: + combined_output = combined_output.clone().reshape(orig_shape[:-1] + [combined_output.shape[-1]]) + return combined_output, combine_weights, router_loss2, gate_logits diff --git a/paddleformers/nn/moe/moe_block.py b/paddleformers/nn/moe/moe_block.py new file mode 100644 index 00000000000..0051b62c842 --- /dev/null +++ b/paddleformers/nn/moe/moe_block.py @@ -0,0 +1,136 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
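+# Hedged usage sketch for the factory defined below (the `gate`, `experts`,
+# `moe_group` and `hidden_states` objects are assumed to exist; the return
+# signature shown is that of the alltoall layer):
+#
+#     moe = create_moe_block(
+#         gate=gate,
+#         experts=experts,
+#         layer_idx=0,
+#         group=moe_group,
+#         k=2,
+#         moe_num_experts=64,
+#         moe_mode="alltoall",
+#     )
+#     out, combine_weights, router_loss, gate_logits = moe(hidden_states)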
+ + +""" +moe_layer_all_gather +""" + +from typing import List, Optional + +import paddle +from paddle import nn +from paddle.distributed.communication.group import Group + +from .abstract import MOELayerBase +from .moe_allgather_layer import MOEAllGatherLayerV2 +from .moe_alltoall_layer import MOEAlltoAllLayer +from .moe_ep_layer import QwenMoeBlock + + +def create_moe_block( + gate: nn.Layer = None, + experts: List[nn.Layer] = None, + layer_idx = -1, + shared_experts: Optional[List[nn.Layer]] = None, + group: Group = None, + recompute=False, + k=2, + enable_reverse_token_drop=False, + all_to_all_dropout=0, + group_experts=False, + use_expert_out_alltoall=True, # + use_padding=True, + dense_token_type=3, # considerd as dense tokens (no moe) + moe_statics=None, + moe_num_experts=None, + config=None, + expert_class=None, + use_shared_expert=True, + moe_mode="allgather", +) -> MOELayerBase: + if moe_mode == "allgather": + model = MOEAllGatherLayerV2( + gate, + experts, + layer_idx, + shared_experts, + group, + recompute, + k, + enable_reverse_token_drop, + all_to_all_dropout, + group_experts, + use_expert_out_alltoall, # + use_padding, + dense_token_type, # considerd as dense tokens (no moe) + moe_statics, + moe_num_experts, + ) + elif moe_mode == "alltoall": + model = MOEAlltoAllLayer( + gate, + experts, + layer_idx, + shared_experts, + group, + recompute, + k, + all_to_all_dropout, + group_experts, + moe_statics, + moe_num_experts, + ) + elif moe_mode == "qwen": + model = QwenMoeBlock( + config, + expert_class, + use_shared_expert + ) + else: + raise ValueError("Invalid moe_mode") + + return model + + +class MoEStatics(nn.Layer): + """ + Stores MoE (Mixture of Experts) statistics + and expert usage information. + """ + + def __init__(self, config, layer_idx): + """ + Initialize MoE statistics tracking. + + Args: + config: Model configuration containing MoE parameters + layer_idx: Index of the MoE layer in the model + """ + super().__init__() + self._cast_to_low_precision = False # 兼容develop分支paddle + self._cast_to_low_precison = False + num_experts = config.moe_num_experts[0] if config.multimodel_experts else config.moe_num_experts + if config.multimodel_experts: + assert ( + len(set(config.moe_num_experts)) == 1 + ), f"assume expert group has same size, got: {config.moe_num_experts}" + + with paddle.utils.unique_name.guard(f"mm_layer_{layer_idx}_"): + num_experts_groups = len(config.moe_num_experts) if config.multimodel_experts else 1 + p = self.create_parameter( + shape=[num_experts_groups, num_experts], + dtype="float32", + is_bias=True, + attr=paddle.ParamAttr(name=paddle.utils.unique_name.generate("corr_bias")), + ) + p.stop_gradient = True + self.e_score_correction_bias = p + self.e_score_correction_bias.is_distributed = True + p = paddle.zeros( + shape=[num_experts_groups, num_experts], + dtype="int64", + ) + p.stop_gradient = True + self.expert_usage = p diff --git a/paddleformers/nn/moe/moe_ep_gate.py b/paddleformers/nn/moe/moe_ep_gate.py new file mode 100644 index 00000000000..4011b7dc09a --- /dev/null +++ b/paddleformers/nn/moe/moe_ep_gate.py @@ -0,0 +1,617 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleformers.utils.log import logger + + +class MoEGateMixin: + def gate_score_func(self, logits: paddle.Tensor) -> paddle.Tensor: + # [..., hidden_dim] -> [..., num_experts] + with paddle.amp.auto_cast(False): + scoring_func = getattr(self, "scoring_func", None) + if scoring_func == "softmax": + scores = F.softmax(logits.cast("float32"), axis=-1) + elif scoring_func == "sigmoid": + scores = F.sigmoid(logits.cast("float32")) + elif scoring_func == "tanh": + scores = F.tanh(logits.cast("float32")) + elif scoring_func == "relu": + scores = F.relu(logits.cast("float32")) + elif scoring_func == "gelu": + scores = F.gelu(logits.cast("float32")) + elif scoring_func == "leaky_relu": + scores = F.leaky_relu(logits.cast("float32")) + else: + logger.warning_once( + f"insupportable scoring function for MoE gating: {scoring_func}, use softmax instead" + ) + scores = F.softmax(logits.cast("float32"), axis=-1) + return scores + + def gumbel_rsample(self, logits: paddle.Tensor) -> paddle.Tensor: + gumbel = paddle.distribution.gumbel.Gumbel(0, 1) + return gumbel.rsample(logits.shape) + + def uniform_sample(self, logits: paddle.Tensor) -> paddle.Tensor: + uniform = paddle.distribution.uniform.Uniform(0, 1) + return uniform.sample(logits.shape) + + @paddle.no_grad() + def _one_hot_to_float(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.get_default_dtype()) + + @paddle.no_grad() + def _one_hot_to_int64(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.int64) + + @paddle.no_grad() + def _capacity( + self, + gates: paddle.Tensor, + capacity_factor: float, + max_capacity: int, + min_capacity: int, + ) -> paddle.Tensor: + """Calculate the capacity for each expert based on the gates and capacity factor. + + Args: + gates (paddle.Tensor): A tensor of shape [num_tokens, num_experts] representing the probability distribution + over experts for each token. + capacity_factor (float): A scalar float value representing the capacity factor for each expert. + min_capacity (int): A scalar integer value representing the minimum capacity for each expert. + + Returns: + int: A tensor value representing the calculated capacity for each expert. 
+ """ + assert gates.ndim == 2, f"gates should be 2D, but got {gates.ndim}, {gates.shape}" + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + capacity = int((num_tokens // num_experts) * capacity_factor) + if capacity < min_capacity: + capacity = min_capacity + if capacity > max_capacity: + capacity = max_capacity + assert capacity > 0, f"requires capacity > 0, capacity_factor: {capacity_factor}, input_shape: {gates.shape}" + + return capacity + + def _cal_aux_loss(self, gates, mask): + """ + Calculate auxiliary loss + + Args: + gates (paddle.Tensor): Represents the output probability of each expert. The shape is [batch_size, num_experts] + mask (paddle.Tensor): Represents whether each sample belongs to a certain expert. The shape is [batch_size, num_experts] + + Returns: + paddle.Tensor: The value of auxiliary loss. + + """ + # TODO: @DrownFish19 update aux_loss for Qwen2MoE and DeepSeekV2&V3 + me = paddle.mean(gates, axis=0) + ce = paddle.mean(mask.cast("float32"), axis=0) + if self.global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=self.group) + dist.all_gather(ce_list, ce, group=self.group) + + me_list[self.rank] = me + ce_list[self.rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + aux_loss = paddle.sum(me * ce) * float(self.num_experts) + return aux_loss + + def _cal_seq_aux_loss(self, gates, top_k, topk_idx) -> paddle.Tensor: + """ + Calculate sequence auxiliary loss. + + Args: + logits (paddle.Tensor): Model output. + + Returns: + paddle.Tensor: The value of sequence auxiliary loss. + """ + batch_size, seq_len, _ = gates.shape + ce = paddle.zeros([batch_size, self.num_experts]) + topk_idx = topk_idx.reshape([batch_size, -1]) + ce.put_along_axis_(indices=topk_idx, values=paddle.ones([batch_size, seq_len * top_k]), axis=1, reduce="add") + ce = ce / (seq_len * top_k / self.num_experts) + aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean() + return aux_loss + + def _cal_z_loss(self, logits) -> paddle.Tensor: + """ + Calculate the z loss. + + Args: + logits (paddle.Tensor): Model output. The shape is [batch_size, num_experts]. + + Returns: + paddle.Tensor: The z loss value. + """ + l_zloss = paddle.logsumexp(logits, axis=1).square().mean() + return l_zloss + + def _cal_orthogonal_loss(self) -> paddle.Tensor: + """Gate weight orthogonal loss. 
+ + Returns: + Paddle.Tensor: orthogonal loss + """ + weight = F.normalize(self.weight, axis=0) + orthogonal_loss = paddle.mean(paddle.square(paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts))) + return orthogonal_loss + + +class PretrainedMoEGate(nn.Layer, MoEGateMixin): + def __init__(self, config, num_experts, expert_hidden_size, **kwargs): + super(PretrainedMoEGate, self).__init__() + + self.config = config + + self.num_experts = num_experts + self.expert_hidden_size = expert_hidden_size + + # force keep in float32 when using amp + self._cast_to_low_precision = False + + self.capacity_factor = kwargs.pop("capacity_factor", 1.0) + self.eval_capacity_factor = kwargs.pop("eval_capacity_factor", 1.0) + self.min_capacity = kwargs.pop("min_capacity", 1.0) + self.max_capacity = kwargs.pop("max_capacity", pow(2, 32)) + + self.group = kwargs.pop("group", None) + self.global_aux_loss = kwargs.pop("global_aux_loss", False) + if self.global_aux_loss: + assert self.group is not None, "group is required when global_aux_loss is True" + self.rank = dist.get_rank(self.group) + + self.expert_drop = kwargs.pop("expert_drop", False) + self.noisy_gate_policy = kwargs.pop("noisy_gate_policy", None) + self.drop_tokens = kwargs.pop("drop_tokens", True) + self.use_rts = kwargs.pop("use_rts", True) + self.top2_2nd_expert_sampling = kwargs.pop("top2_2nd_expert_sampling", True) + + self.drop_policy = kwargs.pop("drop_policy", "probs") + # Qwen2MoE: greedy + # DeepSeekV2&V3: group_limited_greedy for training, and noaux_tc for inference + self.topk_method = kwargs.pop("topk_method", "greedy") + self.top_k = kwargs.pop("top_k", 2) + self.n_group = kwargs.pop("n_group", 1) # for group_limited_greedy + self.topk_group = kwargs.pop("topk_group", 1) # for group_limited_greedy + self.norm_topk_prob = kwargs.pop("norm_topk_prob", False) + self.routed_scaling_factor = kwargs.pop("routed_scaling_factor", 1.0) + + def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor: + """_summary_ + The priority is the cumulative sum of the expert indices. + + This method is used in hunyuan model + Args: + topk_idx (paddle.Tensor): [batch_size * seq_len, topk] + + Returns: + paddle.Tensor: cumsum locations + """ + _, k = topk_idx.shape + # Shape: [seq_len * k] + chosen_expert = topk_idx.reshape([-1]) + # Shape: [seq_len * k, num_experts]. + token_priority = F.one_hot(chosen_expert, self.num_experts).cast(paddle.int32) + token_priority = paddle.logical_and(token_priority > 0, token_priority.cumsum(axis=0) <= capacity) + # Shape: [seq_len, num_experts]. 
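+        # The cumulative sum along the token axis counts, per expert, how many
+        # tokens have already selected it; only the first `capacity` selections
+        # keep a nonzero priority. E.g. (illustrative values) with capacity=2
+        # and one expert chosen by tokens 0, 3 and 5, the running counts are
+        # 1, 2, 3, so token 5 is dropped for that expert.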
+ token_priority = token_priority.reshape([-1, k, self.num_experts]).sum(axis=1) + + return (token_priority > 0.0).astype("float32") + + def _topk_greedy(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + """ + topk_weight, topk_idx = paddle.topk(scores, k=k, axis=-1, sorted=True) + return topk_weight, topk_idx + + def _topk_group_limited_greedy( + self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + group_scores = scores.reshape([0, n_group, -1]).max(axis=-1) # [n, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True) + + return topk_weight, topk_idx + + def _topk_noaux_tc( + self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + assert self.e_score_correction_bias is not None, "e_score_correction_bias is None" + scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.reshape([bsz_seq_len, self.n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) + ) # fmt:skip [n, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0, dtype="float32"), axis=-1) # fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores_for_choice * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True) + topk_weight = scores.take_along_axis(topk_idx, axis=1) if not self.training else topk_weight + + return topk_weight, topk_idx + + def top1gating( + self, + 
logits: paddle.Tensor, + used_token: paddle.Tensor = None, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements Top1Gating on logits.""" + if self.noisy_gate_policy == "RSample": + logits += self.gumbel_rsample(logits.shape) + + gates = self.gate_score_func(logits=logits) + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) + + # Create a mask for 1st's expert per token + # noisy gating + # Only save the position of the maximum value + indices1_s = paddle.argmax(logits if self.noisy_gate_policy == "RSample" else gates, axis=1) + # Convert the position of the maximum value to a one-hot vector [s, e] + mask1 = self._one_hot_to_float(indices1_s, num_classes=self.num_experts) + + # mask only used tokens + if used_token is not None: + mask1 = paddle.einsum( + "s,se->se", used_token, mask1 + ) # Element-wise multiply used_token with mask1 to obtain a new mask1 + + # gating decisions + exp_counts = paddle.sum(mask1, axis=0) # Calculate the number of tokens for each expert + + # if we don't want to drop any tokens + if not self.drop_tokens: + new_capacity = paddle.max(exp_counts) # Calculate the number of tokens for each expert + # Communicate across expert processes to pick the maximum capacity. + if self.group is not None: + dist.all_reduce( + new_capacity, op=dist.ReduceOp.MAX, group=self.group + ) # Calculate the maximum value among expert processes + # Make sure the capacity value does not exceed the number of tokens. + capacity = int(min(new_capacity, paddle.tensor(mask1.size(0)))) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # Random Token Selection + if self.use_rts: + mask1_rand = mask1 * self.uniform_sample(mask1) + else: + mask1_rand = mask1 + + assert ( + logits.shape[0] >= self.min_capacity + ), "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." + + _, top_idx = paddle.topk(mask1_rand, k=capacity, axis=0) # Select top_capacity tokens + + new_mask1 = mask1 * paddle.zeros_like(mask1).put_along_axis( + top_idx, paddle.to_tensor(1.0, dtype="float32"), axis=0 + ) + mask1 = new_mask1 + + # Compute locations in capacity buffer + locations1 = paddle.cumsum(mask1, axis=0) - 1 # Compute the position of each token in mask1 + + # Store the capacity location for each token + locations1_s = paddle.sum(locations1 * mask1, axis=1).cast(paddle.int64) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + gates = gates / gates * mask1_float + + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + combine_weights = paddle.einsum("se,sc->sec", gates, locations1_sc) + dispatch_mask = combine_weights.cast(paddle.bool).detach() + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def top2gating( + self, + logits: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + # everything is in fp32 in this function + gates = self.gate_score_func(logits=logits) + + # Create a mask for 1st's expert per token. + indices1_s = paddle.argmax(gates, axis=1) # [S, 1] + mask1 = self._one_hot_to_int64(indices1_s, self.num_experts) # [S, E] + + if self.top2_2nd_expert_sampling: + # Create a mask for 2nd's expert per token using Gumbel-max trick. 
+ # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + logits += self.gumbel_rsample(logits) + + # Replace top-expert with min value + logits_except1 = logits.masked_fill(mask1.cast(paddle.bool), float("-inf")) # [S, E] + indices2_s = paddle.argmax(logits_except1, axis=1) # [S, 1] + mask2 = self._one_hot_to_int64(indices2_s, self.num_experts) # [S, E] + + # Note: mask1 and mask2 can be combined to form a single mask. + # mask = paddle.concat([mask1, mask2], axis=0) + # locations = paddle.cumsum(mask, axis=0) - 1 + # locations1, locations2 = locations.split(2, axis=0) + # Compute locations in capacity buffer. + locations1 = paddle.cumsum(mask1, axis=0) - 1 # [S, E] + locations2 = paddle.cumsum(mask2, axis=0) - 1 # [S, E] + # Update 2nd's location by accounting for locations of 1st. + locations2 += paddle.sum(mask1, axis=0, keepdim=True) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # gating decisions + exp_counts = paddle.sum(mask1 + mask2, axis=0) + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) + # Remove locations outside capacity from mask. + mask1 *= (locations1 < capacity).cast(paddle.int64) + mask2 *= (locations2 < capacity).cast(paddle.int64) + else: + # Do not drop tokens - set capacity according to current expert assignments + new_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(new_capacity) + + # Store the capacity location for each token. + locations1_s = paddle.sum(locations1 * mask1, axis=1) + locations2_s = paddle.sum(locations2 * mask2, axis=1) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + mask2_float = mask2.cast(paddle.float32) + gates1_s = paddle.einsum("se,se->s", gates, mask1_float) + gates2_s = paddle.einsum("se,se->s", gates, mask2_float) + denom_s = gates1_s + gates2_s + # Avoid divide-by-zero + denom_s = paddle.clip(denom_s, min=paddle.finfo(denom_s.dtype).eps) + gates1_s /= denom_s + gates2_s /= denom_s + + # Calculate combine_weights and dispatch_mask + gates1 = paddle.einsum("s,se->se", gates1_s, mask1_float) + gates2 = paddle.einsum("s,se->se", gates2_s, mask2_float) + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + locations2_sc = self._one_hot_to_float(locations2_s, capacity) + combine1_sec = paddle.einsum("se,sc->sec", gates1, locations1_sc) + combine2_sec = paddle.einsum("se,sc->sec", gates2, locations2_sc) + combine_weights = combine1_sec + combine2_sec + dispatch_mask = combine_weights.cast(paddle.bool) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def topkgating( + self, + gates: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements TopKGating on logits.""" + batch_size, seq_len, d_model = gates.shape + gates_ori = gates + gates = gates.reshape([-1, d_model]) + + l_zloss = self._cal_z_loss(gates) + + # get topk gates + if self.topk_method == "greedy": + top_gate, top_idx = self._topk_greedy(gates, k=self.top_k) + elif self.topk_method == "group_limited_greedy": + top_gate, top_idx = self._topk_group_limited_greedy( + gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + elif self.topk_method == "noaux_tc": + top_gate, top_idx = self._topk_noaux_tc( + gates, 
k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + # norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 + top_gate = top_gate / denominator + top_gate = top_gate * self.routed_scaling_factor + + # get topk mask + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0, dtype="float32"), axis=1) + if hasattr(self.config, "seq_aux") and self.config.seq_aux: + l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx) + else: + l_aux = self._cal_aux_loss(gates, mask) + + exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) + + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity( + gates, + self.capacity_factor * self.top_k, + self.max_capacity, + self.min_capacity, + ) + + # update mask and locations by capacity + if self.drop_policy == "probs": + topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) + capacity_probs, capacity_indices = paddle.topk(topk_masked_gates, k=capacity, axis=0, sorted=False) + token_priority = self._priority(capacity_indices, capacity) + + elif self.drop_policy == "position": + token_priority = self._priority(top_idx, capacity) + else: + raise ValueError(f"Invalid drop_policy: {self.drop_policy}") + else: + # Do not drop tokens - set capacity according to current expert assignments + local_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(local_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(local_capacity) + token_priority = self._priority(top_idx, capacity) + + # normalize gates + # gates_masked is equal to top_gate. + gates_masked = gates * mask + # if self.training: + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + if self.norm_topk_prob: + gates_masked = gates_masked / denom_s + gates_masked *= self.routed_scaling_factor + + return ( + capacity, + gates_masked.take_along_axis(top_idx, axis=-1), + top_idx, + token_priority.take_along_axis(top_idx, axis=-1), + l_aux, + l_zloss, + ) + + def topkgating_nodrop(self, gates: paddle.Tensor): + """Implements TopKGating on logits.""" + batch_size, seq_len, d_model = gates.shape + gates_ori = gates + gates = gates.reshape([-1, d_model]) + + l_zloss = self._cal_z_loss(gates) + + # get topk gates + if self.topk_method == "greedy": + top_gate, top_idx = self._topk_greedy(gates, k=self.top_k) + elif self.topk_method == "group_limited_greedy": + top_gate, top_idx = self._topk_group_limited_greedy( + gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + elif self.topk_method == "noaux_tc": + top_gate, top_idx = self._topk_noaux_tc( + gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + # norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 + top_gate = top_gate / denominator + top_gate = top_gate * self.routed_scaling_factor + + # get topk mask + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + + if hasattr(self.config, "seq_aux") and self.config.seq_aux: + l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx) + else: + l_aux = self._cal_aux_loss(gates, mask) + + exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) + topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, 
axis=1) + return topk_masked_gates, mask, exp_counts, l_aux, l_zloss + +class QwenMoeGate(PretrainedMoEGate): + def __init__(self, config, num_experts, expert_hidden_size, **kwargs): + super().__init__(config, num_experts, expert_hidden_size, **kwargs) + # [hidden_size, n_expert] + self.weight = paddle.create_parameter( + shape=[expert_hidden_size, num_experts], + dtype=paddle.get_default_dtype(), + is_bias=False, + default_initializer=nn.initializer.Constant(1.0), + ) + + def forward(self, hidden_states): + """ + Args: + hidden_states (_type_): [batch_size * seq_len, hidden_size] + """ + _, _, h_dim = hidden_states.shape + + # compute gating score + logits = F.linear(hidden_states, self.weight, None) + + with paddle.amp.auto_cast(False): + scores = self.gate_score_func(logits=logits) + scores = scores.cast(paddle.get_default_dtype()) + + capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss \ No newline at end of file diff --git a/paddleformers/nn/moe/moe_ep_layer.py b/paddleformers/nn/moe/moe_ep_layer.py new file mode 100644 index 00000000000..e3cd5c5cae2 --- /dev/null +++ b/paddleformers/nn/moe/moe_ep_layer.py @@ -0,0 +1,417 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any, List, Tuple + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle import Tensor, nn +from paddle.distributed.communication.group import Group +import paddle.nn.functional as F + +from .moe_ep_gate import PretrainedMoEGate, QwenMoeGate +from .token_dispatcher import MoEFlexTokenDispatcher +from .abstract import MOELayerBase + + +def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): + """ + Rearranges the input tensor `x` based on gate results, truncates it according to the specified capacity, and performs padding. + + Args: + x (Tensor)[Seq, Dim]: The input tensor. + dispatch_mask (List[Tensor[Seq, 1], Tensor[Seq, 1]]): A list of dispatch masks. + scatter_index (Union[List[Tensor[Seq,], Tensor[Seq]], Tensor[Seq, 2]]): A list or tensor representing scatter indices. + num_experts (int): The number of experts. + capacity (int): The capacity size. + + Returns: + Tensor [Expert*Capacity, Dim]: The output tensor after dispatching. 
+ """ + output = None + orig_dtype = x.dtype + if isinstance(scatter_index, paddle.Tensor): + scatter_index = scatter_index.unbind(1) + for i_scatter_index, i_dispatch_mask in zip(scatter_index, dispatch_mask): + init_output = paddle.zeros([num_experts * capacity, x.shape[-1]], dtype="float32") + updates = x * i_dispatch_mask.cast(x.dtype) + if output is None: + output = paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + else: + output = output + paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + if output.dtype != orig_dtype: + output = output.cast(orig_dtype) + return output + + +def combining(x, combine_weights, scatter_index): + """ + Performs combination and aggregation operations on the input matrix. + + Args: + x: Tensor[num_experts * capacity, dim] - The input matrix to be processed, where the last dimension represents the number of features. + combine_weights: Union[List[Tensor[seq, 1], Tensor[seq, 1]], Tensor[seq, 2, 1]] - A list or tensor containing combination weights for each feature. + scatter_index: Union[List[Tensor[seq], Tensor[seq]], Tensor[seq, 2]] - A tuple of indices indicating which elements are to be aggregated, where the first element is the row index and the second element is the column index. + + Returns: + Tensor: The output matrix after combination and aggregation, with a shape of [n, dim * num_features], where n is the number of samples in the input matrix. + """ + + dim = x.shape[-1] + if isinstance(scatter_index, (list, tuple)): + scatter_index = paddle.concat([i.unsqueeze([-1]) for i in scatter_index], -1) + scatter_index = scatter_index.reshape([-1]) + num_k = len(combine_weights) if isinstance(combine_weights, (list, tuple)) else combine_weights.shape[-1] + x = paddle.gather(x, scatter_index).reshape([-1, num_k, dim]) # [seq,2,dim] + if isinstance(combine_weights, (list, tuple)): + combine_weights = paddle.concat(combine_weights, -1).unsqueeze([1]) + return paddle.matmul(combine_weights, x).squeeze(1) # [seq,1,2] @ [seq,2,dim] -> [seq,1,dim] + + +class _AllToAll(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx: Any, + output_shape: List, + input: Tensor, + out_split_sizes: List = None, + in_split_sizes: List = None, + group: Group = None, + ) -> Tensor: # type: ignore + """ + All-to-all communication in the group. + Args: + ctx (Any): Context object. + output_shape (List): Output shape. + input (Tensor): Input tensor. + out_split_sizes (List): Output split sizes. + in_split_sizes (List): Input split sizes. + group (Group): The group object. + Returns: + Tensor: Output tensor. + """ + + ctx.group = group + ctx.input_shape = input.shape + ctx.out_split_sizes = out_split_sizes + ctx.in_split_sizes = in_split_sizes + + # return input + if dist.get_world_size(group) <= 1: + return input + + output = paddle.empty(output_shape, dtype=input.dtype) + task = dist.alltoall_single( + output, + input, + out_split_sizes=out_split_sizes, + in_split_sizes=in_split_sizes, + sync_op=False, + group=group, + ) + task.wait() + + return output + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor]: + """ + Aggregates gradient information from all input tensors into a single tensor. + Args: + ctx (Any): The context object used to store information that needs to be passed. + *grad_output (Tensor): A list of input tensors whose gradients are to be aggregated. + Returns: + Tuple[Tensor]: A tuple containing a tensor that holds the gradients of all input tensors. 
+ """ + # return grad_output + return _AllToAll.apply(ctx.input_shape, *grad_output, ctx.in_split_sizes, ctx.out_split_sizes, ctx.group) + + +class MoELayer(MOELayerBase): + def __init__( + self, + config, + moe_num_experts: int, + expert_class: nn.Layer, + expert_kwargs: dict, + gate: PretrainedMoEGate, + capacity: int = 1.0, + moe_group: str = "data", + all_to_all_dropout=0.0, + ): + super().__init__() + + self.config = config + + self.moe_num_experts = moe_num_experts + self.capacity = capacity + + try: + dist.fleet.get_hybrid_communicate_group() + is_fleet_init = True + except AttributeError: + is_fleet_init = False + + if ( + is_fleet_init + and dist.fleet.get_hybrid_communicate_group().get_data_parallel_world_size() > 1 + and moe_group == "data" + ): + self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + self.moe_rank = dist.get_rank(self.moe_group) + self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank + self.expert_parallel_degree = dist.get_world_size(self.moe_group) + self.expert_parallel_degree = 1 if self.expert_parallel_degree < 0 else self.expert_parallel_degree + self.moe_num_experts_per_device = self._parse_moe_expert_parallel( + self.moe_num_experts, self.expert_parallel_degree + ) + self.is_dummy_moe = False if self.expert_parallel_degree > 1 else True + else: + # when moe_group is dummy, we don't need to use all_to_all + self.moe_group = None + self.moe_rank = 0 + self.expert_parallel_degree = 1 + self.moe_num_experts_per_device = self.moe_num_experts + self.is_dummy_moe = True + + self.all_to_all_dropout = all_to_all_dropout + self.enable_recompute = False + + self.experts = nn.LayerList([]) + for i in range(self.moe_num_experts): + if i // self.moe_num_experts_per_device == self.moe_rank: + self.experts.append(expert_class(**expert_kwargs)) + else: + self.experts.append(None) + + self.gate = gate + self.gate.group = self.moe_group + self._post_init() + + def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): + assert ( + moe_num_experts >= expert_parallel_degree + ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={expert_parallel_degree}" + assert ( + moe_num_experts % expert_parallel_degree == 0 + ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0" + moe_num_experts_per_device = moe_num_experts // expert_parallel_degree + return moe_num_experts_per_device + + def _post_init(self): + for p in self.gate.parameters(): + p.is_gate = True + + for k in self.experts: + if k is not None: + for p in k.parameters(): + p.expert = not self.is_dummy_moe + p.no_sync = not self.is_dummy_moe + # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") + + def forward( + self, + hidden_state: paddle.Tensor, + ): + """MoE Layer forward function + 1. Gate Forward. + 2. Dispatch export. + 3. Experts Forward. + + Args: + hidden_state: MoE Layer input + + Returns: + final_out: MoE Layer main output. + l_aux: MoE auxiliary loss. 
l_zloss: MoE z loss.""" + batch_size, seq_len, d_model = hidden_state.shape + + reshaped_input = hidden_state.reshape([-1, d_model]) + + # self.l_aux : + # topk_weight : se + # topk_ids : sk + # token_priority : se + # self.exp_counts : + capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss = self.gate(hidden_state) + + """MoE expert dispatch from: https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py""" + cnts = paddle.zeros([topk_ids.shape[0], len(self.experts)], dtype=topk_ids.dtype) + cnts = cnts.put_along_axis(topk_ids, 1, axis=1) + + tokens_per_expert = cnts.sum(axis=0) + idxs = topk_ids.reshape([topk_ids.shape[0] * topk_ids.shape[1]]).argsort() + sorted_tokens = reshaped_input[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.detach() + sorted_tokens_shape = sorted_tokens.shape + + if self.expert_parallel_degree > 1: + tokens_per_ep_rank = tokens_per_expert.reshape([self.expert_parallel_degree, -1]).sum(axis=1) + tokens_per_expert_group = _AllToAll.apply( + [tokens_per_expert.shape[0]], tokens_per_expert, group=self.moe_group + ) + output_splits = ( + tokens_per_expert_group.reshape([self.expert_parallel_degree, -1]).sum(axis=1).cpu().tolist() + ) + input_split_sizes = tokens_per_ep_rank.cpu().tolist() + gathered_tokens = _AllToAll.apply( + [tokens_per_expert_group.sum(axis=0).cpu().item(), sorted_tokens.shape[1]], + sorted_tokens, + out_split_sizes=output_splits, + in_split_sizes=input_split_sizes, + group=self.moe_group, + ) + + tokens_per_expert_post_gather = tokens_per_expert_group.reshape( + [self.expert_parallel_degree, self.moe_num_experts_per_device] + ).sum(axis=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.moe_num_experts_per_device + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.moe_rank * self.moe_num_experts_per_device] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + outs = paddle.concat(outputs, axis=0) if len(outputs) > 0 else paddle.to_tensor(0, dtype=sorted_tokens.dtype) + if self.expert_parallel_degree > 1: + new_x = paddle.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = _AllToAll.apply( + sorted_tokens_shape, + new_x, + out_split_sizes=input_split_sizes, + in_split_sizes=output_splits, + group=self.moe_group, + ) + outs = gathered_tokens + + new_x = paddle.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.reshape(topk_ids.shape + [-1]) + .astype(topk_weight.dtype) + .multiply_(topk_weight.unsqueeze(-1)) + .multiply_(token_priority.unsqueeze(-1)) + .sum(axis=1) + .astype(new_x.dtype) + .reshape([batch_size, seq_len, -1]) + ) + + return final_out, l_aux, l_zloss + + +class MoEFlexTokenLayer(nn.Layer): + def __init__(self, config, moe_num_experts, expert_class, expert_kwargs, gate, moe_group): + + super().__init__() + self.config = config + self.moe_group = moe_group + self.ep_size = dist.get_world_size(self.moe_group) + self.moe_router_topk = gate.top_k + self.moe_num_experts = moe_num_experts + self.num_local_experts = moe_num_experts // self.ep_size + 
self.token_dispatcher = MoEFlexTokenDispatcher( + self.num_local_experts, self.moe_router_topk, self.moe_num_experts, moe_group + ) + + self.experts = nn.LayerList([expert_class(**expert_kwargs)] * self.num_local_experts) + self.router = gate + + def expert_forward(self, dispatched_input, tokens_per_expert): + outputs = [] + tokens_per_expert = tokens_per_expert.tolist() + # print(f"all tokens: {sum(tokens_per_expert)}, detail: {tokens_per_expert}") + chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0) + for chunk, expert in zip(chunks, self.experts): + chunk = chunk.contiguous() + # assert chunk.shape[0] != 0, "Cannot dispatch empty input" + outputs += [expert(chunk)] + + return paddle.concat(outputs, axis=0) + + def forward(self, hidden_states: paddle.Tensor): + _, _, d_model = hidden_states.shape + # reshaped_input = hidden_states.reshape([-1, d_model]) + probs, routing_map, l_aux, l_zloss = self.router(hidden_states) + (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( + hidden_states, probs, routing_map + ) + expert_output = self.expert_forward(dispatched_input, tokens_per_expert) + output, _ = self.token_dispatcher.token_unpermutation(expert_output, None) + return output, l_aux, l_zloss + +class QwenMoeBlock(MoELayer): + def __init__(self, config, expert_class, use_shared_expert=True): + gate = QwenMoeGate( + config, + config.num_experts, + config.hidden_size, + top_k=config.num_experts_per_tok, + drop_tokens=False, + ) + + super().__init__( + config, + moe_num_experts=config.num_experts, + expert_class=expert_class, + expert_kwargs={"config": config}, + gate=gate, + capacity=2.0, + ) + + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.use_shared_expert = use_shared_expert + + if use_shared_expert: + self.shared_expert = expert_class(config, is_shared=True) + self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias_attr=False) + + def forward(self, hidden_states): + final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) + + if self.use_shared_expert: + shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + final_hidden_states = final_hidden_states + shared_expert_output + + return final_hidden_states, l_aux \ No newline at end of file diff --git a/paddleformers/nn/moe/token_dispatcher.py b/paddleformers/nn/moe/token_dispatcher.py new file mode 100644 index 00000000000..9fb4c335d1a --- /dev/null +++ b/paddleformers/nn/moe/token_dispatcher.py @@ -0,0 +1,284 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 DeepSeek +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
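+
+"""
+Token dispatchers for expert-parallel MoE layers.
+
+This module provides `_DeepepManager`, a dispatch/combine manager built on the fused
+DeepEP all-to-all kernels (`fused_dispatch` / `fused_combine`), and
+`MoEFlexTokenDispatcher`, which permutes tokens to their routed experts before expert
+computation and restores the original token order afterwards.
+"""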
+ +from abc import ABC, abstractmethod +from typing import Optional, Tuple + +import paddle +from paddle.distributed.communication.group import Group + +from .fused_a2a import fused_combine, fused_dispatch +from .utils import permute, unpermute + + +class _DispatchManager(ABC): + """ + A manager class to handle dispatch and combine processes for MoE models. + + DispatcherManager handles token dispatching according to the routing_map of format + [num_local_tokens, world_size, num_instances]. The routing_map is a 3D tensor where each + element indicates whether a token should be sent to a specific rank. + + num_instances is the maximum number of tokens instances dispatched into a target rank, it + can be the number of local experts, or the size of sub_group. + """ + + @abstractmethod + def setup_metadata(self, routing_map: paddle.Tensor, probs: paddle.Tensor): + """Set up metadata of routing_map and probs.""" + pass + + @abstractmethod + def dispatch(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + """Dispatch the hidden_states according to the routing_map.""" + pass + + @abstractmethod + def combine(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + """Combine the hidden_states after expert processing.""" + pass + + @abstractmethod + def get_dispatched_metadata(self) -> paddle.Tensor: + """Get the metadata of the dispatched hidden_states.""" + pass + + @abstractmethod + def get_permuted_hidden_states_by_experts(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + """Get the permuted hidden states by instances.""" + pass + + @abstractmethod + def get_restored_hidden_states_by_experts(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + """Get the restored hidden states by instances.""" + pass + + +class _DeepepManager(_DispatchManager): + """ + A manager class to handle fused all-to-all communication processes for MoE models using + DeepEP backend. See https://github.com/deepseek-ai/deepep for more details. + + The workflow of the DeepEP dispatcher is: + (1) setup_metadata(): Process routing map and probabilities to prepare dispatch metadata + (2) dispatch(): + - Use fused kernel to permute tokens and perform all-to-all communication in single step + (3) get_permuted_hidden_states_by_instances(): + - Convert routing map and probabilities to multihot format + - Permute tokens using fused kernel + (4) get_restored_hidden_states_by_instances(): + - Reverse permutation using fused kernel + (5) combine(): + - Reverse process using fused kernel to unpermute and perform all-to-all in single step + + This implementation uses fused communication kernels (fused_dispatch/fused_combine) that + combine permutation and communication operations for improved efficiency compared to + separate permute+alltoall steps. 
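+
+    Note: in this implementation, steps (3) and (4) above correspond to
+    `get_permuted_hidden_states_by_experts()` and
+    `get_restored_hidden_states_by_experts()`.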
+ """ + + def __init__( + self, + group: Group, + router_topk: int, + num_experts: int = None, + num_local_experts: int = None, + ): + self.group = group + self.router_topk = router_topk + self.num_experts = num_experts + self.num_local_experts = num_local_experts + + # Metadata + self.token_indices = None + self.token_probs = None + # Handle used for combine operation + self.handle = None + + if fused_dispatch is None: + raise ImportError("DeepEP is not supported in your paddlepaddle whl package.") + + def setup_metadata(self, routing_map: paddle.Tensor, probs: paddle.Tensor): + num_tokens = routing_map.shape[0] + + routing_map = routing_map.reshape([num_tokens, self.num_experts]) + probs = probs.reshape([num_tokens, self.num_experts]) + # Convert the format of routing map from multihot to indices. + self.token_probs, self.token_indices = paddle.topk(probs, self.router_topk, axis=-1) + + def dispatch(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states, dispatched_probs, states = fused_dispatch( + hidden_states, self.token_indices, self.token_probs, self.num_experts, self.group + ) + self.handle = states["handle"] + self.tokens_per_expert = states["tokens_per_expert"] + self.dispatched_indices = states["dispatched_indices"] + self.dispatched_probs = dispatched_probs + + return hidden_states + + def _indices_to_multihot(self, indices, probs): + """ + Converts a tensor of indices to a multihot vector. + + Args: + indices (paddle.Tensor): [num_tokens, topk] token indices, where -1 means masked out. + probs (paddle.Tensor): [num_tokens, topk] token probabilities. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: + - routing_map: Multihot vector. + - probs: Multihot probabilities. + """ + batch_size = indices.shape[0] + multihot_routing_map = paddle.zeros((batch_size, self.num_local_experts), dtype=paddle.int64) + + multihot_probs = paddle.zeros((batch_size, self.num_local_experts), dtype=paddle.float32) + + mask = indices != -1 + valid_indices = indices[mask] + row_indices = paddle.arange(batch_size).repeat_interleave(mask.sum(axis=1)) + multihot_routing_map[row_indices, valid_indices] = 1 + multihot_probs[row_indices, valid_indices] = probs[mask] + return multihot_routing_map.cast(paddle.bool), multihot_probs + + def get_dispatched_metadata(self) -> paddle.Tensor: + return self.dispatched_indices, self.dispatched_probs + + def get_number_of_tokens_per_expert(self) -> paddle.Tensor: + """ + Get the number of tokens per expert. 
+ """ + return self.tokens_per_expert + + def combine(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states = fused_combine(hidden_states, self.group, self.handle) + # Release the handle after combine operation + self.handle = None + return hidden_states + + def get_permuted_hidden_states_by_experts(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + self.dispatched_routing_map, self.dispatched_probs = self._indices_to_multihot( + self.dispatched_indices, self.dispatched_probs + ) + self.hidden_shape_before_permute = hidden_states.shape + hidden_states, self.reversed_mapping_for_combine = permute( + hidden_states, + self.dispatched_routing_map, + num_out_tokens=sum(self.tokens_per_expert), + ) + return hidden_states + + def get_restored_hidden_states_by_experts(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + input_dtype = hidden_states.dtype + assert self.dispatched_probs.dtype == paddle.float32, "DeepEP only supports float32 probs" + hidden_states = unpermute( + hidden_states, + self.reversed_mapping_for_combine, + restore_shape=self.hidden_shape_before_permute, + routing_map=self.dispatched_routing_map, + probs=self.dispatched_probs, + ) + return hidden_states.to(input_dtype) + + +class MoETokenDispatcher: + """ + MoE Token Dispatcher + """ + + def __init__(self, ep_group) -> None: + """ + Initialize the MoE Token Dispatcher. + """ + self._ep_group = ep_group + + @property + def ep_group(self): + """Get expert model parallel group.""" + return self._ep_group + + @property + def ep_size(self): + """Get expert model parallel world_size.""" + return self.ep_group.world_size + + @abstractmethod + def token_permutation(self, tokens: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor): + """Dispatch tokens to experts. + + Args: + tokens (paddle.Tensor): Input tokens. + probs (paddle.Tensor): The routing probability tensor [num_tokens, num_experts]. + routing_map (paddle.Tensor): Token to expert mapping tensor. + + Returns: + paddle.Tensor: Tokens tensor. + """ + raise NotImplementedError("Dispatch function not implemented.") + + @abstractmethod + def token_unpermutation(self, expert_output: paddle.Tensor, bias: paddle.Tensor = None): + """Restores the expert output to its original ordering. + + Args: + expert_output (paddle.Tensor): The output tensor from the expert models. + bias (paddle.Tensor): The bias tensor. + + Returns: + (paddle.Tensor, paddle.Tensor): Unpermuted activation and optional bias. + """ + raise NotImplementedError("Restore function not implemented.") + + +class MoEFlexTokenDispatcher(MoETokenDispatcher): + """ + Flexible token dispatcher for MoE models with Efficient-A2A communication kernels. 
+ """ + + def __init__(self, num_local_experts: int, moe_router_topk: int, num_moe_experts: int, ep_group: Group): + super().__init__(ep_group) + + self.num_local_experts = num_local_experts + assert self.ep_size > 1, "Flex token dispatcher requires EP > 1" + self._comm_manager = _DeepepManager( + group=self.ep_group, + router_topk=moe_router_topk, + num_experts=num_moe_experts, + num_local_experts=self.num_local_experts, + ) + + def token_permutation( + self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + self.hidden_shape = hidden_states.shape + hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) + + self._comm_manager.setup_metadata(routing_map, probs) + hidden_states = self._comm_manager.dispatch(hidden_states) + global_input_tokens = self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states) + tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert() + + return global_input_tokens, tokens_per_expert + + def token_unpermutation( + self, hidden_states: paddle.Tensor, bias: Optional[paddle.Tensor] = None + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]: + assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher" + hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(hidden_states) + hidden_states = self._comm_manager.combine(hidden_states) + + hidden_states = hidden_states.reshape(self.hidden_shape) + return hidden_states, None diff --git a/paddleformers/nn/moe/topk_gate.py b/paddleformers/nn/moe/topk_gate.py new file mode 100644 index 00000000000..d62b53eb1a8 --- /dev/null +++ b/paddleformers/nn/moe/topk_gate.py @@ -0,0 +1,575 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +top2gate +""" + +from functools import partial +from typing import Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import Tensor, _C_ops, nn +from paddle.distributed import fleet +from paddle.incubate.nn.functional import int_bincount +from paddle.nn.clip import _squared_l2_norm +from paddle.utils import unique_name + +from paddleformers.utils.log import logger + +if paddle.device.is_compiled_with_custom_device("npu"): + from .npu_fusion_ops import npu_cal_aux_loss_func as cal_aux_loss +else: + from paddle.incubate.nn.functional import cal_aux_loss + + +def masked_fill(x, mask, value): + """ + Fills elements of the input tensor with a given value where mask is True. 
+ + Args: + x (Tensor): Input tensor to be modified + mask (Tensor): Boolean mask tensor (same shape as x) + value (float|int): Value to fill masked elements with + + Returns: + Tensor: New tensor with masked elements replaced by value + """ + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +@paddle.no_grad() +def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10): + """ + Computes optimal transport matrix and Sinkhorn distance using Sinkhorn-Knopp algorithm. + + Args: + M (Tensor): Cost matrix (n x m) + r (Tensor): Source marginals (n,) + c (Tensor): Target marginals (m,) + lam (float): Entropic regularization strength + epsilon (float): Convergence threshold + max_iters (int): Maximum iterations + + Returns: + tuple: (optimal transport matrix, Sinkhorn distance) + """ + n, _ = M.shape + P = F.softmax(-M / lam) + u = paddle.zeros(n, "float32") + # normalize this matrix + for _ in range(max_iters): + if (u - P.sum(1)).abs().max() < epsilon: + break + u = P.sum(1) + P *= (r / (u + 1e-8)).reshape((-1, 1)) + P *= (c / (P.sum(0) + 1e-8)).reshape((1, -1)) + P = paddle.where(~P.isnan(), P, paddle.zeros_like(P)) + return P, _ + + +def cast_if_needed(x, dtype): + """ + Casts tensor to specified dtype if not already in that dtype. + + Args: + x (Tensor): Input tensor + dtype: Target dtype + + Returns: + Tensor: Casted tensor + """ + return x.cast(dtype) if x.dtype != dtype else x + + +class FusedGateDetachMatmul(paddle.autograd.PyLayer): + """ + Custom autograd function for fused gate-detached matrix multiplication. + Optimizes forward/backward passes for MoE routing computations. + """ + + @staticmethod + def forward(ctx, x, w): + """ + Forward pass for fused matmul operation. + + Args: + ctx: Context object + x (Tensor): Input tensor + w (Tensor): Weight matrix + + Returns: + Tensor: Result of matrix multiplication + """ + ctx.dtype = paddle.float32 + ctx.save_for_backward(x, w) + return F.linear(cast_if_needed(x, ctx.dtype), cast_if_needed(w, ctx.dtype)) + + @staticmethod + def backward(ctx, y_grad): + """ + Backward pass for gradient computation. + + Args: + ctx: Context object + y_grad (Tensor): Gradient from upstream + + Returns: + tuple: Gradients with respect to inputs + """ + x, w = ctx.saved_tensor() + assert ctx.dtype == y_grad.dtype, "dtype not match" + x_g, w_g = _C_ops.matmul_grad(cast_if_needed(x, ctx.dtype), cast_if_needed(w, ctx.dtype), y_grad, False, False) + + # Especially fix for lora training. + if w.stop_gradient: + return cast_if_needed(x_g, x.dtype), None + return cast_if_needed(x_g, x.dtype), cast_if_needed(w_g, w.dtype) + + +def gate_detach_matmul(x, weight, use_fuse): + """ + Performs gate-detached matrix multiplication with optimization options. + + Args: + x (Tensor): Input tensor + weight (Tensor): Weight matrix + use_fuse (bool): Whether to use fused implementation + + Returns: + Tensor: Result of matrix multiplication + """ + if use_fuse: + return FusedGateDetachMatmul.apply(x, weight) + else: + x = cast_if_needed(x, paddle.float32) + return F.linear(x, weight) + + +class TopKGate(nn.Layer): + """ + Fused version of TopK gate for improved performance. + """ + + def __init__(self, config, layer_idx: int, group, gate_weight=None) -> None: + """ + Initialize the MoE (Mixture of Experts) layer. 
+ + Args: + config: Model configuration containing MoE parameters + layer_idx: Index of this layer in the model + group: Distributed communication group + gate_weight: Optional pre-existing gate weight tensor + """ + super().__init__() + self.config = config + + self.fuse_gate_detach_matmul = config.fuse_gate_detach_matmul + + self.model_dim = config.hidden_size + self.num_experts = config.moe_num_experts + self.num_experts_tensor = sum(config.moe_num_experts) if config.multimodel_experts else config.moe_num_experts + + self.cap = config.moe_capacity + self.group = group + + self.layer_idx = layer_idx + self.global_aux_loss = config.global_aux_loss + if self.global_aux_loss: + self.rank = dist.get_rank(self.group) + + self.sinkhorn_2gate = config.sinkhorn_2gate + self.sinkhorn_temp = config.sinkhorn_temp + self.use_correction_bias = config.moe_use_aux_free # true + self.use_token_type_bias = config.get("moe_use_token_type_bias", False) + + if config.moe_gate_act == "softmax": + self.act = partial(F.softmax, axis=-1) # [S,E] + elif config.moe_gate_act == "sigmoid": + self.act = F.sigmoid + else: + raise ValueError(f"{config.moe_gate_act} is not supported.") + self.no_jitter = True + self.expert_drop = False + self.eye_matrix = None + self.eye_matrix_size = None + self.norm_gate_logits = config.moe_norm_gate_logits # true + self.one = paddle.ones([], dtype="float32") + + self.moe_aux_loss_lambda = paddle.to_tensor(config.moe_aux_loss_lambda, dtype="float32") + self.moe_z_loss_lambda = paddle.to_tensor(config.moe_z_loss_lambda, dtype="float32") + self.moe_orthogonal_loss_lambda = paddle.to_tensor(config.moe_orthogonal_loss_lambda, dtype="float32") + if self.moe_aux_loss_lambda.ndim == 0: + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) + if self.moe_z_loss_lambda.ndim == 0: + self.moe_z_loss_lambda = self.moe_z_loss_lambda.unsqueeze(0) + if self.moe_orthogonal_loss_lambda.ndim == 0: + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.unsqueeze(0) + + self.experts_type_ids = None + if config.moe_orthogonal_loss_lambda: + if hasattr(fleet.fleet, "_user_defined_strategy"): + strategy = fleet.fleet._user_defined_strategy + sharding_configs = strategy.hybrid_configs["sharding_configs"] + pp_config = strategy.hybrid_configs["pp_configs"] + assert ( + not sharding_configs.comm_overlap and not pp_config.sharding_comm_overlap + ), "orthogonal loss will cause twice gradient accumulate, will break pp/sharding overlap" + + self.eps = paddle.to_tensor([1e-12], dtype="float32") + if config.multimodel_experts: + if config.get("moe_use_hard_gate", False): + self.num_experts_list = [] + self.experts_type_mask = [] + # hard-gate + group_experts 需要对gate_logits不同部分分开计算 + experts_ids = paddle.zeros([sum(self.num_experts)], dtype="int64").reshape([config.moe_world_size, -1]) + offset = 0 + for i, expert_num in enumerate(self.num_experts): + experts_ids[:, offset : offset + expert_num // config.moe_world_size] = i + offset += expert_num // config.moe_world_size + self.experts_type_ids = experts_ids.reshape([-1]) + logger.info(f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}") + for i, expert_num in enumerate(self.num_experts): + self.experts_type_mask.append( + self.experts_type_ids == i, + ) + self.num_experts_list.append(expert_num) + else: + # 非group_experts, 依赖token_type_bias实现hard-gate能力。 + assert not config.moe_group_experts, "group_experts must use hard_gate when multimodel_experts is True" + else: + self.num_experts_list = [self.num_experts] + if gate_weight is 
not None: + self.weight = gate_weight + assert ( + not self.config.moe_use_token_type_bias + ), "gate_weights is from outside, token_type_bias can't be used" + logger.info("moe use gate_weight from outside") + # use fp32 pecison in amp + self._cast_to_low_precision = False + self._cast_to_low_precison = False + else: + self._create_gate_parameter() + logger.info( + f"{config.moe_gate}: w/ capacity: {self.cap} experts:{self.num_experts} " + f"use_token_type_bias:{self.use_token_type_bias} " + f"gate_act:{config.moe_gate_act} " + f"norm_gate_logits={self.norm_gate_logits} use_correction_bias={self.use_correction_bias}" + ) + + def _create_gate_parameter(self): + """ + Create gate weight parameter. + """ + if self.config.multimodel_experts: + # support setting lambda for each expert group + self.moe_z_loss_lambda = self.moe_z_loss_lambda.expand(len(self.num_experts)) + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand(len(self.num_experts)) + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.expand(len(self.num_experts)) + + for i, num_experts in enumerate(self.num_experts): + if i == 1: + with paddle.utils.unique_name.guard(f"mm_gate_{self.layer_idx}_"): + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + ) + else: + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + ) + p.expert_type = f"expert_type_{i}" + self.add_parameter( + "weight" if i == 0 else f"weight_{i}", # 为了对齐原 state-dict,第一个 gate-weight 不改名. + p, + ) + else: + self.weight = self.create_parameter( + shape=[self.model_dim, self.num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), # for resume dense-ckpt + ) + # use fp32 pecison in amp + self._cast_to_low_precision = False + self._cast_to_low_precison = False + + def get_gate_weight(self, transform_weight): + """ + 在`multimodel_experts` 的情况下,将多个 weights merge 成一个整体 + transform_weight: bool, 按照 local-expert id 将 多模态 weight 交叠 + """ + if not self.config.multimodel_experts: + return self.weight + if not transform_weight: + return paddle.concat( + [getattr(self, "weight" if i == 0 else f"weight_{i}") for i in range(len(self.num_experts))], -1 + ) + weight = paddle.zeros( + [ + self.model_dim, + self.config.moe_world_size, + sum(self.num_experts) // self.config.moe_world_size, + ], + dtype="float32", + ) + offset = 0 + for i, num_experts in enumerate(self.num_experts): + weight[:, :, offset : offset + num_experts // self.config.moe_world_size] = getattr( + self, "weight" if i == 0 else f"weight_{i}" + ).reshape([self.model_dim, self.config.moe_world_size, -1]) + offset += num_experts // self.config.moe_world_size + weight = weight.reshape([self.model_dim, -1]) + return weight + + def forward( + self, + input: Tensor, + token_type_ids: Tensor = None, + transform_weight: bool = True, + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Forward pass for fused gate. 
+ + Args: + input: Input tensor + token_type_ids: Token type IDs + transform_weight: Whether to transform weights + + Returns: + tuple: (logits, capacity, router_loss) + """ + capacity = self.get_capacity(input.shape[0]) + weight = self.get_gate_weight(transform_weight) + with paddle.amp.auto_cast(False): + logits = gate_detach_matmul(input, weight, self.fuse_gate_detach_matmul) + if self.use_token_type_bias: + assert token_type_ids is not None + assert ( + token_type_ids.max() < self.bias.shape[0] + ), f"token_type_ids {token_type_ids.max()} >= bias shape {self.bias.shape[0]}" + bias = self.bias[token_type_ids] # [seq] + logits = logits + bias + + router_loss = paddle.zeros([1], dtype="float32") + router_loss.stop_gradient = False + + return logits, capacity, router_loss + + def get_capacity(self, num_tokens, cap_factor=None): + """ + Calculate capacity based on number of tokens. + + Args: + num_tokens: Number of input tokens + cap_factor: Optional capacity factor override + + Returns: + int: Calculated capacity + """ + num_experts = sum(self.num_experts) if self.config.multimodel_experts else self.num_experts + if cap_factor is not None: + cap = cap_factor + else: + if self.training: + cap = self.cap[0] + elif num_tokens < num_experts: # seqlen < num_expert + cap = self.cap[2] + else: + cap = self.cap[1] + # capacity = 2S/E + capacity = int(cap * num_tokens // num_experts) + assert capacity > 0, f"requires capacity to >= 0. cap={cap}, num_tokens={num_tokens}" + return capacity + + def _cal_aux_loss( + self, gate_prob, dispatch_mask, num_experts=None, use_group=None, tokens_mask=None, dispatch_tokens_mask=None + ): + """ + Calculate auxiliary loss for router. + + Args: + gate_prob: Gate probabilities tensor + dispatch_mask: Dispatch mask tensor + num_experts: Number of experts + use_group: Whether to use expert groups + tokens_mask: Tokens mask + dispatch_tokens_mask: Dispatch tokens mask + + Returns: + Tensor: Calculated auxiliary loss + """ + if self.act is F.sigmoid: + gate_prob = gate_prob / gate_prob.sum(-1, keepdim=True) + + if self.use_correction_bias: + if tokens_mask is not None: + gate_prob_this_modality = gate_prob[tokens_mask.astype("bool")] + if gate_prob_this_modality.shape[0]: + _, top_idx = gate_prob_this_modality.topk(k=self.config.moe_k, axis=-1) + dispatch_mask = int_bincount(top_idx.reshape([-1]), 0, gate_prob.shape[-1], paddle.int64) + else: + dispatch_mask = paddle.zeros(gate_prob.shape[-1], dtype="int64") + dist.stream.all_reduce( + dispatch_mask, + group=self.group, + use_calc_stream=True, + ) + else: + _, top_idx = gate_prob.topk(k=self.config.moe_k, axis=-1) + dispatch_mask = int_bincount(top_idx.reshape([-1]), 0, gate_prob.shape[-1], paddle.int64) + if num_experts is None: + num_experts = self.num_experts_tensor + if use_group is None: + use_group = self.config.moe_group_experts + + if ( + (tokens_mask is None or len(tokens_mask.shape) == 1) + and (tokens_mask is None or tokens_mask.shape[0] == gate_prob.shape[0]) + and gate_prob.shape[0] >= gate_prob.shape[1] + ): + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + l_aux, seqlen_float, ce = cal_aux_loss( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + self.config.moe_k, + clip_min=1e-6, + ) + return l_aux + + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + + scale = None + if dispatch_tokens_mask is not None: + 
seqlen_float = dispatch_tokens_mask.astype(gate_prob.dtype).sum() + if tokens_mask is not None and gate_prob.shape[0] != dispatch_tokens_mask.shape[0]: + scale = seqlen_float / paddle.clip(tokens_mask.sum(), min=1e-6) + elif tokens_mask is not None: + seqlen_float = tokens_mask.sum() + else: + seqlen_float = gate_prob.numel().astype(gate_prob.dtype) / num_experts + seqlen_float = paddle.clip(seqlen_float, min=1e-6) + + if len(dispatch_mask.shape) == 2: + dispatch_mask = dispatch_mask.sum(0) + ce = dispatch_mask.astype(gate_prob.dtype).detach() / seqlen_float + me = paddle.sum(gate_prob, axis=0) / seqlen_float + if self.global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=self.group) + dist.all_gather(ce_list, ce, group=self.group) + + me_list[self.rank] = me + ce_list[self.rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + l_aux = paddle.sum(me * ce) * num_experts + if use_group: + l_aux = l_aux / self.config.moe_k + + if scale is not None: + # forward local me, backward global me + l_aux = l_aux + (scale - self.one) * l_aux.detach() + + return l_aux + + def _cal_z_loss(self, logits, loss_mask=None): + """ + Calculate Z-loss for router. + + Args: + logits: Input logits tensor + loss_mask: Optional loss mask + + Returns: + Tensor: Calculated Z-loss + """ + + if loss_mask is not None: + loss_mask = loss_mask.astype(logits.dtype) + l_zloss = (logits.logsumexp(1).square() * loss_mask).sum() / paddle.clip(loss_mask.sum(), min=1e-6) + else: + l_zloss = logits.logsumexp(1).square().mean() + # TODO group_experts 分group计算zloss + return l_zloss + + def _cal_orthogonal_loss_opt_each_weight(self, weight, use_group): + """ + Calculate optimized orthogonal loss for each weight. + + Args: + weight: Weight tensor + use_group: Whether to use expert groups + + Returns: + Tensor: Calculated orthogonal loss + """ + if weight.dtype != paddle.float32: + weight = weight.astype(paddle.float32) + + weight = weight.transpose([1, 0]).contiguous() # transpose weight here + wnorm = weight.norm(axis=1) + weight = weight / paddle.maximum(wnorm, self.eps).unsqueeze(1) + + if use_group: + weight = weight.reshape([self.config.moe_k, -1, weight.shape[1]]) # [K, E/K, H] + eye_matrix = paddle.eye(weight.shape[1], dtype=weight.dtype).unsqueeze(0) + else: + eye_matrix = paddle.eye(weight.shape[0], dtype=weight.dtype) + + weight_matmul = paddle.matmul(weight, weight, transpose_y=True) + + orthogonal_loss = weight_matmul - eye_matrix + orthogonal_loss = _squared_l2_norm(orthogonal_loss) / orthogonal_loss.size + return orthogonal_loss + + def _cal_orthogonal_loss(self, weight_id=None, use_group=None): + """ + Calculate orthogonal loss for router weights. 
+ + Args: + weight_id: Optional weight ID + use_group: Whether to use expert groups + + Returns: + Tensor: Calculated orthogonal loss + """ + if use_group is None: + use_group = self.config.moe_group_experts and self.config.moe_group_orthogonal_loss + + if weight_id is not None: + if weight_id == 0: + w_ = self.weight + else: + assert self.config.multimodel_experts + w_ = getattr(self, f"weight_{weight_id}") + return self._cal_orthogonal_loss_opt_each_weight(w_, use_group) + + orthogonal_loss = self._cal_orthogonal_loss_opt_each_weight(self.weight, use_group) + if self.config.multimodel_experts: + for i in range(1, len(self.config.moe_num_experts)): + w_ = getattr(self, f"weight_{i}") + orthogonal_loss += self._cal_orthogonal_loss_opt_each_weight(w_, use_group=False) + return orthogonal_loss diff --git a/paddleformers/nn/moe/utils.py b/paddleformers/nn/moe/utils.py new file mode 100644 index 00000000000..488e1d05a00 --- /dev/null +++ b/paddleformers/nn/moe/utils.py @@ -0,0 +1,426 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Common distributed utils. +""" + +from typing import Any, Callable, List, Optional + +import paddle +from paddle import distributed as dist +from paddle import framework +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.incubate.tensor.manipulation import create_async_load + + +def get_hcg(): + """ + Get hybrid communicate group. + """ + return fleet.get_hybrid_communicate_group() + + +def scatter_axis(input, group=None, axis=0): + """ + Uniformly splits the `input` along dimension 0 across model parallel groups. + This API is not related to `distributed.scatter`. + + Args: + input: Input tensor to be split + group: Communication group for parallel processing (default: model parallel group) + axis: Dimension along which to split (default: 0) + + Returns: + A slice of the input tensor corresponding to this rank's portion + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone() + rank = group.rank + seq_len = input.shape[axis] + assert seq_len % parallelism == 0, ( + f"Input sequence length {seq_len} can't be divided exactly" f" by sequence parallelism {parallelism}" + ) + interval = seq_len // parallelism + input = paddle.slice(input, axes=[axis], starts=[interval * rank], ends=[interval * (rank + 1)]) + # slice uses stride, so we maintain the memory of whole input, use assign to free the whole input + # which can avoid OOM. + input = paddle.assign(input) + return input + + +class ReduceScatterGroupOp(PyLayer): + """ + Perform group reduce scatter. + """ + + @staticmethod + def forward(ctx, input, group=None): + """Forward pass: Reduce-Scatter operation + Args: + input (Tensor): Input tensor with shape [s, b, h]. + The 's' dimension will be split across model parallel group. 
+ group (ProcessGroup): Model parallel process group, + uses global group by default. + Returns: + Tensor: Output tensor after Reduce-Scatter with shape [s/n, b, h], + each device holds partial data of the original input. + """ + ctx.group = group + return reduce_scatter_group(input, group=group) + + @staticmethod + def backward(ctx, grad): + """Backward pass: All-Gather operation + Args: + grad (Tensor): Upstream gradient with shape [s/n, b, h] + Returns: + Tensor: Full gradient after All-Gather with restored shape [s, b, h], + aggregating gradients from all devices in model parallel group. + """ + return all_gather_group(grad, group=ctx.group) + + +class AllGatherGroupOp(PyLayer): + """ + Perform group allgather. + """ + + @staticmethod + def forward(ctx, input, group=None): + """Forward pass: All-Gather operation + Args: + input (Tensor): Partitioned tensor with shape [s/n, b, h] + The 's' dimension is distributed across devices + group (ProcessGroup): Model parallel process group, + uses global group by default + Returns: + Tensor: Assembled tensor after All-Gather with shape [s, b, h], + containing full parameter from all devices + """ + ctx.group = group + return all_gather_group(input, group=group) + + @staticmethod + def backward(ctx, grad): + """Backward pass: Reduce-Scatter operation + Args: + grad (Tensor): Full gradient tensor with shape [s, b, h] + Returns: + Tensor: Scattered gradient with shape [s/n, b, h], + distributing reduced gradients to each device + """ + return reduce_scatter_group(grad, group=ctx.group) + + +def get_async_loader(): + """get_async_loader""" + global async_loader + if not hasattr(fleet.fleet, "_hcg"): + if async_loader is None: + async_loader = create_async_load() + return async_loader + + hcg = get_hcg() + if not hasattr(hcg, "async_loader"): + hcg.async_loader = create_async_load() + return hcg.async_loader + + +def hack_offload_wait(task): + """hack_offload_wait""" + task.cpu_wait() + + +def all_gather_group(input, group=None, axis=0): + """Perform collective all-gather operation across a process group with axis control. + + Functional Behavior: + - Aggregates input tensors from all processes in the specified group + - Supports concatenation along arbitrary dimensions (axis parameter) + - Optimizes for axis=0 via direct shape expansion to avoid concatenation overhead + + Args: + input (Tensor): Local tensor to be gathered (shape: [..., D, ...]) + group (ProcessGroup): Communication group (defaults to model parallel group) + axis (int): Concatenation dimension (default=0) + + Returns: + Tensor: Concatenated tensor combining inputs from all processes: + - When axis=0: shape [D*N, ...] (N = group size) + - Otherwise: shape [..., D*N, ...] along specified axis + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone() + output_shape = input.shape + if axis == 0: + output_shape[axis] = output_shape[axis] * parallelism + output = paddle.empty(shape=output_shape, dtype=input.dtype) + dist.stream.all_gather(output, input, group=group, use_calc_stream=True) + return output + outputs = [paddle.empty(output_shape, dtype=input.dtype) for _ in range(parallelism)] + dist.stream.all_gather(outputs, input, group=group, use_calc_stream=True) + output = paddle.concat(outputs, axis=axis) + return output + + +def reduce_scatter_group(input, group=None): + """Perform reduce-scatter collective operation across a process group. 
+ + Functional Behavior: + - Aggregates (sums) input tensors across all processes in the group + - Scatters the reduced result equally to all participants + - Operates along the first dimension (axis=0) of the input tensor + + Args: + input (Tensor): Local tensor to reduce (shape: [N*K, ...] where N=group_size) + group (ProcessGroup): Communication group (defaults to model parallel group) + + Returns: + Tensor: Scattered portion of reduced tensor with shape [K, ...] + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone() + output_shape = input.shape + assert ( + input.shape[0] % parallelism == 0 + ), f"Input sequence length {input.shape[0]} can't be divided exactly by sequence parallelism {parallelism}" + output_shape[0] = output_shape[0] // parallelism + output = paddle.empty(shape=output_shape, dtype=input.dtype) + dist.stream.reduce_scatter(output, input, op=dist.ReduceOp.SUM, group=group, use_calc_stream=True) + return output + + +class ScatterOp(PyLayer): + """ + Each rank slices its own portion from the **same** sequence (uniformly split). + During backward pass, gradients from all ranks are aggregated to restore + the mp (model parallelism) synchronization state. + The inverse operation is `GatherOp`. + + input: Tensor [S,*] + + Note: Not related to `distributed.scatter`. + """ + + @staticmethod + def forward(ctx, input, axis=0, group=None): + """forward""" + ctx.axis = axis + ctx.group = group + return scatter_axis(input, axis=axis, group=ctx.group) + + @staticmethod + def backward(ctx, grad): + """backward""" + return all_gather_group(grad, axis=ctx.axis, group=ctx.group) + + +def detach_and_requires_grad_(*args): + """ + Detach tensors while preserving their requires_grad status. + + Args: + args: Input tensors + + Returns: + list: Detached tensors + """ + ret = [a.detach() if a is not None else None for a in args] + for r, a in zip(ret, args): + if a is not None: + r.stop_gradient = a.stop_gradient + return ret + + +class FakeClone(paddle.autograd.PyLayer): + """ + Fake clone operation that preserves computation graph without data copy. + """ + + @staticmethod + def forward(ctx, input): + """ + Create fake clone of input tensor. + + Args: + input: Input tensor + + Returns: + Tensor: Fake cloned tensor + """ + if input.is_contiguous(): + fake_output = paddle.empty_like(input) + input._share_buffer_to(fake_output) + else: + fake_output = input.clone() + return fake_output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for fake clone. + + Args: + grad_output: Gradient of output + + Returns: + Tensor: Gradient of input + """ + return grad_output + + +def manual_backward(f: Callable, is_first_fwd: bool, *args: List[Any]): + """ + Perform manual backward pass with gradient tracing control. 
+ + Args: + f: Function to execute + is_first_fwd: Whether this is the first forward pass + args: Arguments for the function + + Returns: + tuple: (backward function, function outputs) + """ + tracer = framework._dygraph_tracer() + orig = tracer._has_grad + if not is_first_fwd: + tracer._has_grad = True # turn on grad trace so we can manual backward + + detached_args = detach_and_requires_grad_(*args) + detached_args_clone = [FakeClone.apply(a) if a is not None else None for a in detached_args] + out = f(*detached_args_clone) + if isinstance(out, list): + out = tuple(out) + elif not isinstance(out, tuple): + out = (out,) + + if is_first_fwd: + tracer._has_grad = orig + return None, out + + out_cached = [FakeClone.apply(o) for o in out if o is not None] # do not cache stop_gradient output + + for o in out_cached: + o._clear_dataptr() # free mem + tracer._has_grad = orig + + def bwd_f(*grad): + nonlocal out_cached, detached_args, f + grad = list(grad) + grad = [g for g in grad if g is not None] + assert grad and out_cached, (len(grad), len(out_cached)) + # out 中的 stop_graident 参数,也会收到 gradient,在这里过滤掉 + grad, out_cached = zip(*[(g, o) for g, o in zip(grad, out_cached) if not o.stop_gradient]) + + assert len(grad) == len(out_cached), (len(grad), len(out_cached), f) + # out, grad = zip(*[(o, g) for o, g in zip(out, grad) if g is not None]) + paddle.autograd.backward(out_cached, grad) + return tuple([t.grad for t in detached_args if t is not None]) + + return bwd_f, out + +def permute( + tokens, + routing_map, + num_out_tokens: Optional[int] = None, + drop_and_pad: bool = False, +): + """Permute the tokens and probs based on the mask. + Tokens with the same designated expert will be grouped together. + The shape of mask is [tokens, num_experts], it indicates which experts were selected + by each token. + + Args: + tokens (paddle.Tensor): The input token tensor, [num_tokens, hidden]. + routing_map (paddle.Tensor): The sparse token to expert mapping, [num_tokens, num_experts]. + num_out_tokens (int, optional): The number of output tokens. If None, it's set to + the number of input tokens. + drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop + and pads the number of tokens to the expert capacity. + """ + assert not drop_and_pad, "token-drop and pads is not supported" + num_tokens, hidden = tokens.shape + num_experts = routing_map.shape[1] + + # mask [num_tokens, num_experts] -> [num_experts, num_tokens] + routing_map = routing_map.cast(paddle.bool).T.contiguous() + + # Create a dense expert-to-token mapping from the sparse token-to-expert mapping + token_indices = paddle.arange(num_tokens).unsqueeze(0).expand([num_experts, -1]) + sorted_indices = token_indices.masked_select(routing_map) + + # use the mapping to permute the tokens + permuted_input = tokens.index_select(axis=0, index=sorted_indices) + + return permuted_input, sorted_indices + + +def unpermute( + permuted_tokens: paddle.Tensor, + sorted_indices: paddle.Tensor, + restore_shape: paddle.shape, + probs: paddle.Tensor = None, + routing_map: paddle.Tensor = None, + drop_and_pad: bool = False, +): + """ + Restore the original order of tokens after permutation. If probs are provided, it + will also apply them to the tokens before restoring the order. + + Args: + permuted_tokens (paddle.Tensor): The permuted token tensor. + sorted_indices (paddle.Tensor): The indices used to sort the tokens. + restore_shape (paddle.shape): The shape of the unpermuted tensor. 
+ probs (paddle.Tensor, optional): The unpermuted probs tensor, + routing_map (paddle.Tensor, optional): Token to expert mapping, shape + [num_tokens, num_experts]. + drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop + and pads the number of tokens to the expert capacity. + + Returns: + paddle.Tensor: The tokens restored to their original order. + """ + assert not drop_and_pad, "token-drop and pads is not supported" + _, hidden = restore_shape + + if probs is not None: + assert routing_map is not None, "Mask must be provided to permute the probs." + permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous()) + permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1) + + # Create an output tensor filled with zeros + output_tokens = paddle.zeros(restore_shape, dtype=permuted_tokens.dtype) + # Scatter add the permuted_input back to the original positions + output_tokens.put_along_axis_( + axis=0, + indices=sorted_indices.unsqueeze(1).expand([-1, hidden]), + values=permuted_tokens, + reduce="add", + include_self=True, + ) + return output_tokens \ No newline at end of file diff --git a/paddleformers/transformers/qwen2_moe/modeling.py b/paddleformers/transformers/qwen2_moe/modeling.py index e23237609b0..e527b1dd1cb 100644 --- a/paddleformers/transformers/qwen2_moe/modeling.py +++ b/paddleformers/transformers/qwen2_moe/modeling.py @@ -43,6 +43,8 @@ from ..utils import logger from .configuration import Qwen2MoeConfig +from paddleformers.nn.moe.moe_block import create_moe_block + try: from paddle.incubate.nn.functional import fused_rotary_position_embedding except ImportError: @@ -821,7 +823,10 @@ def __init__(self, config: Qwen2MoeConfig, layerwise_recompute: bool = False): self.self_attn = Qwen2MoeAttention(config, layerwise_recompute) if config.num_experts > 0: - self.mlp = Qwen2MoeSparseMoeBlock(config) + self.mlp = create_moe_block(config=config, + expert_class=Qwen2MoeMLP, + use_shared_expert=True, + moe_mode="qwen") else: # num_experts == 0 or this layer is not sparse layer self.mlp = Qwen2MoeMLP(config) diff --git a/paddleformers/transformers/qwen3_moe/modeling.py b/paddleformers/transformers/qwen3_moe/modeling.py index 56a7a3bd212..6bf17bcc061 100644 --- a/paddleformers/transformers/qwen3_moe/modeling.py +++ b/paddleformers/transformers/qwen3_moe/modeling.py @@ -35,6 +35,7 @@ from ..moe_layer import MoELayer from ..utils import logger from .configuration import Qwen3MoeConfig +from paddleformers.nn.moe.moe_block import create_moe_block try: from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp @@ -166,7 +167,10 @@ def __init__(self, config: Qwen3MoeConfig, layerwise_recompute: bool = False): self.self_attn = Qwen3MoeAttention(config, layerwise_recompute) if config.num_experts > 0: - self.mlp = ExpertParallelQwen3MoeSparseMoeBlock(config) + self.mlp = create_moe_block(config=config, + expert_class=Qwen3MoeMLP, + use_shared_expert=False, + moe_mode="qwen") else: # num_experts == 0 or this layer is not sparse layer self.mlp = Qwen3MoeMLP(config)
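The `create_moe_block` factory imported from `paddleformers.nn.moe.moe_block` in both
call sites above is defined in a file that is not part of this section. The following
is only a minimal sketch of how such a factory could route `moe_mode="qwen"` to the
`QwenMoeBlock` defined in `moe_ep_layer.py`; the function body is an assumption
inferred from the call sites, not the actual implementation.

    # Hypothetical sketch only -- the real factory lives in paddleformers/nn/moe/moe_block.py,
    # which is not shown in this diff. Keyword arguments follow the call sites above.
    from paddleformers.nn.moe.moe_ep_layer import QwenMoeBlock  # assumed import path


    def create_moe_block(config, expert_class, use_shared_expert=True, moe_mode="qwen"):
        """Map a moe_mode string onto a concrete MoE block implementation."""
        if moe_mode == "qwen":
            # QwenMoeBlock builds its own QwenMoeGate from config (see moe_ep_layer.py above).
            return QwenMoeBlock(config, expert_class, use_shared_expert=use_shared_expert)
        raise ValueError(f"Unsupported moe_mode: {moe_mode}")

Under this reading, registering a new MoE variant (for example the flex-token path in
`MoEFlexTokenLayer`) would only require another `moe_mode` branch, leaving the model
files under `transformers/` unchanged.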