From ef09eec4465e89f5fae3b2ddebd91a89b0cf1cbc Mon Sep 17 00:00:00 2001 From: blacksheep-Aristotle Date: Thu, 21 Aug 2025 10:59:58 +0800 Subject: [PATCH] moe_layer support fine_grained_forward --- .../transformers/deepseek_v2/configuration.py | 5 +- paddlenlp/transformers/moe_layer.py | 122 +++++++++++++++++- paddlenlp/transformers/moe_utils.py | 82 ++++++++++++ 3 files changed, 203 insertions(+), 6 deletions(-) diff --git a/paddlenlp/transformers/deepseek_v2/configuration.py b/paddlenlp/transformers/deepseek_v2/configuration.py index d21afc20780f..64b0d8a1b615 100644 --- a/paddlenlp/transformers/deepseek_v2/configuration.py +++ b/paddlenlp/transformers/deepseek_v2/configuration.py @@ -179,6 +179,8 @@ def __init__( attention_dropout=0.0, speculate_model_type=False, using_flex_token=False, + deepep_fine_grained=False, + deepep_tokens_per_subbatch=1024, **kwargs, ): self.vocab_size = vocab_size @@ -227,7 +229,8 @@ def __init__( self.speculate_model_type = speculate_model_type self.use_fp8 = False self.using_flex_token = using_flex_token - + self.deepep_fine_grained = deepep_fine_grained + self.deepep_tokens_per_subbatch = deepep_tokens_per_subbatch super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 040fb5d1f22a..4f274181254f 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -24,7 +24,13 @@ from paddle import Tensor, nn from paddle.distributed.communication.group import Group +try: + from paddle import scatter_add_ +except ImportError: + scatter_add_ = None + from .moe_gate import PretrainedMoEGate +from .moe_utils import FakeGather, topk_to_permuted_indices_single from .token_dispatcher import MoEFlexTokenDispatcher @@ -366,13 +372,119 @@ def expert_forward(self, dispatched_input, tokens_per_expert): return paddle.concat(outputs, axis=0) + def maybe_split_subbatch_data(self, permuted_tokens, token_permuted_indices, prob_permuted_indices): + """maybe_split_subbatch_data""" + + def split_subbatch_data(data, tokens_per_subbatch): + total_token_num = data.shape[0] + + full_batch_num, remainder = divmod(total_token_num, tokens_per_subbatch) + num_or_sections = [tokens_per_subbatch] * full_batch_num + if remainder: + num_or_sections.append(remainder) + + assert ( + sum(num_or_sections) == total_token_num + ), f"get_subbatch_data fail, {sum(num_or_sections)}, {total_token_num}" + # when data is 0-size tensor, we need to compute it and construct the right backward graph. + if total_token_num == 0: + return [data] + return paddle.split(data, num_or_sections=num_or_sections, axis=0) + + if self.config.deepep_tokens_per_subbatch > 0: + assert ( + permuted_tokens.shape[0] == token_permuted_indices.shape[0] + ), f"Shape mismatch between {permuted_tokens.shape[0]} and {token_permuted_indices.shape[0]}" + assert ( + permuted_tokens.shape[0] == prob_permuted_indices.shape[0] + ), f"Shape mismatch between {permuted_tokens.shape[0]} and {prob_permuted_indices.shape[0]}" + permuted_tokens_list = split_subbatch_data(permuted_tokens, self.config.deepep_tokens_per_subbatch) + token_permuted_indices_list = split_subbatch_data( + token_permuted_indices, self.config.deepep_tokens_per_subbatch + ) + prob_permuted_indices_list = split_subbatch_data( + prob_permuted_indices, self.config.deepep_tokens_per_subbatch + ) + else: + permuted_tokens_list = [permuted_tokens] + token_permuted_indices_list = [token_permuted_indices] + prob_permuted_indices_list = [prob_permuted_indices] + return permuted_tokens_list, token_permuted_indices_list, prob_permuted_indices_list + + def fine_grained_forward_experts(self, dispatched_input, dispatched_probs, dispatched_indices, dispatch_topk): + """fine_grained_forward_experts""" + print("moe layer input shape ", self.hidden_shape, " moe layer dispatch output shape ", dispatched_input.shape) + output_tokens = paddle.zeros(dispatched_input.shape, dispatched_input.dtype) + + tokens_per_expert = self.token_dispatcher._comm_manager.tokens_per_expert + + for expert_id, num_tokens in enumerate(tokens_per_expert): + + token_permuted_indices, prob_permuted_indices = topk_to_permuted_indices_single( + dispatched_indices, num_tokens, expert_id, dispatch_topk + ) + permuted_tokens = FakeGather.apply(dispatched_input, token_permuted_indices) + # If deepep_tokens_per_subbatch > 0, the data is split into multiple subbatches. + ( + permuted_tokens_list, + token_permuted_indices_list, + prob_permuted_indices_list, + ) = self.maybe_split_subbatch_data(permuted_tokens, token_permuted_indices, prob_permuted_indices) + + for permuted_tokens_, token_permuted_indices_, prob_permuted_indices_ in zip( + permuted_tokens_list, token_permuted_indices_list, prob_permuted_indices_list + ): + # ffn + permuted_tokens_ = self.experts[expert_id](permuted_tokens_) + # local unpermute + if dispatched_probs is not None: + permuted_probs = FakeGather.apply(dispatched_probs.flatten(), prob_permuted_indices_) + if permuted_tokens_.dtype != permuted_probs.dtype: + new_permuted_tokens = permuted_tokens_.astype(permuted_probs.dtype) + else: + new_permuted_tokens = permuted_tokens_ + permuted_tokens_ = new_permuted_tokens * permuted_probs.unsqueeze(-1) + if scatter_add_ is not None: + scatter_add_(output_tokens, token_permuted_indices_, permuted_tokens_.astype(output_tokens.dtype)) + else: + output_tokens.scatter_( + index=token_permuted_indices_, + updates=permuted_tokens_.astype(output_tokens.dtype), + overwrite=False, + ) + + dispatched_input._clear_to_zero_allocation() + + return output_tokens + def forward(self, hidden_states: paddle.Tensor): _, _, d_model = hidden_states.shape # reshaped_input = hidden_states.reshape([-1, d_model]) probs, routing_map, l_aux, l_zloss = self.router(hidden_states) - (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( - hidden_states, probs, routing_map - ) - expert_output = self.expert_forward(dispatched_input, tokens_per_expert) - output, _ = self.token_dispatcher.token_unpermutation(expert_output, None) + + if self.config.deepep_fine_grained: + # global dispatch + # (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( + # hidden_states, probs, routing_map + # ) + self.hidden_shape = hidden_states.shape + hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) + + self.token_dispatcher._comm_manager.setup_metadata(routing_map, probs) + dispatched_input = self.token_dispatcher._comm_manager.dispatch(hidden_states) + + dispatched_indices = self.token_dispatcher._comm_manager.dispatched_indices + dispatched_probs = self.token_dispatcher._comm_manager.dispatched_probs + # local dispatch & forward_experts & local combine + output_tokens = self.fine_grained_forward_experts( + dispatched_input, dispatched_probs, dispatched_indices, self.moe_router_topk + ) + # global combine + output = self.token_dispatcher._comm_manager.combine(output_tokens) + else: + (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( + hidden_states, probs, routing_map + ) + expert_output = self.expert_forward(dispatched_input, tokens_per_expert) + output, _ = self.token_dispatcher.token_unpermutation(expert_output, None) return output, l_aux, l_zloss diff --git a/paddlenlp/transformers/moe_utils.py b/paddlenlp/transformers/moe_utils.py index d82654f6375b..414984a5155e 100644 --- a/paddlenlp/transformers/moe_utils.py +++ b/paddlenlp/transformers/moe_utils.py @@ -101,3 +101,85 @@ def unpermute( else: output_tokens.scatter_(index=sorted_indices, updates=permuted_tokens, overwrite=False) return output_tokens + + +def topk_to_permuted_indices_single(x, num_tokens, expert_id, topk): + """ + Convert the topk indices to permuted indices. + """ + x = paddle.flatten(x) + prob_permuted_indices = paddle.tensor.search._restrict_nonzero(x == expert_id, num_tokens).flatten() + token_permuted_indices = prob_permuted_indices // topk + return token_permuted_indices, prob_permuted_indices + + +def topk_to_permuted_indices(x, num_tokens_per_expert_list, topk): + """ + Convert the topk indices to permuted indices. + """ + x = paddle.flatten(x) + prob_permuted_indices = paddle.concat( + [ + paddle.tensor.search._restrict_nonzero(x == i, total_true_num) + for i, total_true_num in enumerate(num_tokens_per_expert_list) + ] + ).flatten() + token_permuted_indices = prob_permuted_indices // topk + return token_permuted_indices, prob_permuted_indices + + +class FakeClone(paddle.autograd.PyLayer): + """ + manual_backward中, 为了保留局部的计算图做临时反向计算 + 需要把manual_backward的output给clone出来, 这个clone + 本质上不需要output的值, 而是需要拿到output身上的计算图 + + 但调用paddle.clone会做一次额外的数据拷贝, 这是没必要的 + FakeClone可以免去这个数据拷贝, 实现摘取计算图的目的 + """ + + @staticmethod + def forward(ctx, input): + """forward""" + if input.is_contiguous(): + fake_output = paddle.empty_like(input) + input._share_buffer_to(fake_output) + else: + fake_output = input.clone() + return fake_output + + @staticmethod + def backward(ctx, grad_output): + """backward""" + return grad_output + + +class FakeGather(paddle.autograd.PyLayer): + """ + 临时绕开gather 0size索引的coredump问题 + """ + + @staticmethod + def forward(ctx, input, indices): + """forward""" + assert len(indices.shape) == 1 + ctx.save_for_backward(indices) + ctx.input_shape = input.shape + if indices.shape[0] == 0: + out_shape = input.shape + out_shape[0] = 0 + return paddle.zeros(out_shape, dtype=input.dtype) + return paddle.index_select(input, axis=0, index=indices) + + @staticmethod + def backward(ctx, grad_output): + """backward""" + indices = ctx.saved_tensor() + input_shape = ctx.input_shape + grad_input = paddle.zeros(input_shape, dtype=grad_output.dtype) + if indices.shape[0] != 0: + if scatter_add_ is not None: + scatter_add_(grad_input, indices.unsqueeze(-1), grad_output) + else: + paddle.scatter_(grad_input, indices.unsqueeze(-1), grad_output, overwrite=False) + return grad_input, None