diff --git a/paddleformers/trainer/__init__.py b/paddleformers/trainer/__init__.py index 53ceb66a961..b129b4a20a1 100644 --- a/paddleformers/trainer/__init__.py +++ b/paddleformers/trainer/__init__.py @@ -75,6 +75,8 @@ "TrainerState", "DEFAULT_PROGRESS_CALLBACK", "TrainerCallback", + "StepFlexToken", + "FP8QuantWeightCallback", ], "trainer_utils": [ "get_last_checkpoint", diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py index 8dd4904b4ab..efda89b8cfd 100644 --- a/paddleformers/trainer/trainer.py +++ b/paddleformers/trainer/trainer.py @@ -92,6 +92,10 @@ RowParallelQuantizationLinear, ) +try: + from ..quantization.quantization_linear import QuantizationLinear +except: + QuantizationLinear = None try: from paddle.distributed.fleet.utils.sequence_parallel_utils import ( register_sequence_parallel_allreduce_hooks, @@ -201,6 +205,14 @@ DEFAULT_CALLBACKS = [DefaultFlowCallback] DEFAULT_PROGRESS_CALLBACK = ProgressCallback +# Name of the files used for checkpointing +TRAINING_ARGS_NAME = "training_args.bin" +TRAINER_STATE_NAME = "trainer_state.json" + +SCHEDULER_NAME = "scheduler.pdparams" +SCALER_NAME = "scaler.pdparams" + + if is_datasets_available(): import datasets diff --git a/paddleformers/trainer/trainer_callback.py b/paddleformers/trainer/trainer_callback.py index 812c8dc9f59..d3b50d856b7 100644 --- a/paddleformers/trainer/trainer_callback.py +++ b/paddleformers/trainer/trainer_callback.py @@ -20,12 +20,14 @@ """ import dataclasses import json +import os from dataclasses import dataclass from typing import Dict, List, Optional, Union import numpy as np from tqdm.auto import tqdm +from paddleformers.transformers.moe_utils import offload, reload from ..utils.log import logger from .trainer_utils import IntervalStrategy, has_length from .training_args import TrainingArguments @@ -39,6 +41,8 @@ "ProgressCallback", "PrinterCallback", "EarlyStoppingCallback", + "StepFlexToken", + "FP8QuantWeightCallback", ] @@ -608,3 +612,65 @@ def on_evaluate(self, args, state, control, metrics, **kwargs): self.check_metric_value(args, state, control, metric_value) if self.early_stopping_patience_counter >= self.early_stopping_patience: control.should_training_stop = True + + +class StepFlexToken(TrainerCallback): + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + model = kwargs.pop("model") + if hasattr(model, "step_flex_token"): + model.step_flex_token(state.global_step) + + +g_shard_bypass_dygraph_optimizer = int(os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)) + + +def enable_in_dict_config(config, key): + """enable_in_dict_config""" + return key in config and config[key] + + +skip_count = 0 + + +class FP8QuantWeightCallback(TrainerCallback): + """ + FP8QuantWeightCallback + """ + + def on_step_begin(self, args, state, control, **kwargs): + """ + 每个step开始前把专家参数quant成fp8q + """ + model = kwargs["model"] + optimizer = kwargs["optimizer"] + global skip_count + + if not g_shard_bypass_dygraph_optimizer or skip_count == 0: + model.fp8_quant_weight(True) + optimizer.clear_param_storage("moe_expert") + optimizer.clear_param_storage("rms_linear") + optimizer.clear_param_storage("memory_attn") + optimizer.clear_param_storage("attn_out_project") + optimizer.clear_param_storage("shared_expert") + + self.moe_weights_name = [] + for param in optimizer._inner_opt._parameter_list: + color = getattr(param, "color", -1) + if isinstance(color, dict) and color["color"] == "moe_expert": + 
self.moe_weights_name.append(param.name) + + for name in self.moe_weights_name: + offload(optimizer._master_weights[name]) + + skip_count += 1 + + def on_optimizer_begin(self, args, state, control, **kwargs): + optimizer = kwargs["optimizer"] + for name in self.moe_weights_name: + reload(optimizer._master_weights[name]) diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py index 16bce8e4c71..788e501cbc7 100644 --- a/paddleformers/trainer/training_args.py +++ b/paddleformers/trainer/training_args.py @@ -32,6 +32,7 @@ from paddle.distributed import fleet from ..utils.env import PREFIX_CHECKPOINT_DIR +from ..utils.fault_tolerance import is_ft_env from ..utils.log import logger from ..utils.pdc_sdk import FLASH_DEVICE from .trainer_utils import ( @@ -1397,12 +1398,7 @@ def is_segment_parallel_supported(): else: order = ["dp", "sharding", "pp", "mp"] if self.use_expert_parallel: - if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1: - order.insert(-1, "ep") - sd_idx = order.index("sharding") - # if pp_first, the order = ["dp", "pp", "moe_sharding", "sharding", "sep", "ep", "mp"] - # if sharding_first, the order is ["dp", "moe_sharding", "sharding", "pp", "sep", "ep", "mp"] - order.insert(sd_idx, "moe_sharding") + order = order[1:-1] + ["dp", "mp"] if is_segment_parallel_supported(): hybrid_configs = { @@ -1556,6 +1552,9 @@ def is_segment_parallel_supported(): fleet.init(is_collective=True, strategy=strategy) logger.info(strategy) + if self.expert_parallel_degree > 1: + self.add_moe_comm_group() + elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) self.sep_parallel_degree = max(self.sep_parallel_degree, 1) diff --git a/paddleformers/transformers/fp8_utils.py b/paddleformers/transformers/fp8_utils.py new file mode 100644 index 00000000000..93790005d67 --- /dev/null +++ b/paddleformers/transformers/fp8_utils.py @@ -0,0 +1,1252 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
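Note (annotation, not part of the diff): the new fp8_utils.py module that begins here builds on blockwise FP8 scaling as used by DeepGEMM, where activations carry one dequantization scale per 1x128 block and weights one per 128x128 tile (see dequantize_fp8_to_fp32 and the "1x128"/"128x128" quant_method arguments below). The following standalone NumPy sketch is illustrative only; the helper names and the float8-e4m3 max of 448 are assumptions, while the real code relies on paddle.incubate.nn.functional.fp8_quant_blockwise.

    import numpy as np

    FP8_E4M3_MAX = 448.0  # largest finite float8-e4m3 magnitude (assumed here)
    BLOCK = 128

    def quant_1x128(x):
        """Hypothetical sketch: compute one scale per 1x128 block of a [m, k] matrix."""
        m, k = x.shape
        assert k % BLOCK == 0
        blocks = x.reshape(m, k // BLOCK, BLOCK).astype(np.float32)
        scales = np.abs(blocks).max(axis=-1, keepdims=True) / FP8_E4M3_MAX
        scales = np.where(scales == 0.0, 1.0, scales)  # keep all-zero blocks well defined
        q = np.clip(blocks / scales, -FP8_E4M3_MAX, FP8_E4M3_MAX)  # scaled values fit the fp8 range
        return q.reshape(m, k), scales.squeeze(-1)

    def dequant_1x128(q, scales):
        """Inverse of quant_1x128; mirrors what dequantize_fp8_to_fp32 does per 128-wide block."""
        m, k = q.shape
        return (q.reshape(m, k // BLOCK, BLOCK) * scales[..., None]).reshape(m, k)

    x = np.random.randn(4, 256).astype(np.float32)
    q, s = quant_1x128(x)
    assert np.allclose(dequant_1x128(q, s), x, atol=1e-5)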
+ +import os +from functools import partial + +import numpy +import paddle +import paddle.nn.functional as F + +try: + import fused_ln + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true" + +try: + if USE_DS_GEMM: + import deep_gemm + else: + from paddle.incubate.fp8 import deep_gemm +except: + pass + + +__all__ = [ + "FP8LinearFunctionBase", + "FP8Linear", + "FP8GroupGemmMlpFunctionNode", +] + + +def set_parameter_color( + parameters, color, group=None, offline_quant_expert_weight=True, clear_origin_weight_when_offline_quant=True +): + if offline_quant_expert_weight and clear_origin_weight_when_offline_quant: + if group is None: + for p in parameters: + if hasattr(p, "color") and p.color is not None: + continue + setattr(p, "color", {"color": color}) + else: + for p in parameters: + if hasattr(p, "color") and p.color is not None: + continue + setattr(p, "color", {"color": color, "group": group}) + + +def extract_first_if_tuple(x): + return x[0] if isinstance(x, tuple) else x + + +def _get_fp8_weight_and_scale(weight, stacked=False, transpose=False): + """_get_fp8_weight_and_scale""" + if stacked: + if transpose: + fp8_weight, fp8_scale = weight.fp8_weight_stacked_transpose, weight.fp8_scale_stacked_transpose + else: + fp8_weight, fp8_scale = weight.fp8_weight_stacked, weight.fp8_scale_stacked + else: + if transpose: + fp8_weight, fp8_scale = weight.fp8_weight_transpose, weight.fp8_scale_transpose + else: + fp8_weight, fp8_scale = weight.fp8_weight, weight.fp8_scale + return fp8_weight, fp8_scale + + +def fused_stack_quant(expert_weight_list, transpose=False): + if hasattr(expert_weight_list[0], "fp8_weight_stacked"): + w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=transpose) + else: + w, scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_weight_list, transpose=transpose) + return w, scale + + +def weight_quant(weight, transpose=False): + if transpose: + if hasattr(weight, "fp8_weight_transpose"): + return weight.fp8_weight_transpose, weight.fp8_scale_transpose + else: + return paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=True, + return_transpose_only=True, + ) + else: + if hasattr(weight, "fp8_weight"): + return weight.fp8_weight, weight.fp8_scale + else: + return paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=False, + return_transpose_only=False, + ) + + +class FP8LinearFunctionBase: + @staticmethod + def dequantize_fp8_to_fp32(fp8_tensor, scale): + res = fp8_tensor.reshape([-1, 128]).astype("bfloat16") * (scale.reshape([-1, 1])) + return res.reshape(fp8_tensor.shape) + + @staticmethod + def padding(x, axis): + if x.shape[axis] % 512 != 0: + if (x.shape[axis] + 128 - (x.shape[axis] % 128)) % 512 != 0: + padding_size = 512 + else: + padding_size = 128 + pad_size = padding_size - (x.shape[axis] % padding_size) + if axis == 0: + x = paddle.concat([x, paddle.zeros([pad_size, x.shape[-1]], dtype=x.dtype)], axis=0) + else: + x = paddle.concat([x, paddle.zeros([x.shape[0], pad_size], dtype=x.dtype)], axis=-1) + return x + + @staticmethod + def 
padding_and_quant_input(tensor):
+        """Quantize input to FP8, with fallback to padded transposed version if shape not aligned."""
+        if tensor.shape[0] % 512 != 0:
+            tensor_fp8, tensor_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                tensor, output_scale_transpose=True, quant_method="1x128", input_transpose=False
+            )
+            tensor = FP8LinearFunctionBase.padding(tensor, 0)
+            tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                tensor,
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+                return_transpose_only=True,
+            )
+            return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale
+        else:
+            tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                tensor, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+            )
+            return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale
+
+    @staticmethod
+    def kitchen_gemm(
+        x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16
+    ):
+        if USE_DS_GEMM:
+            if out is None:
+                out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
+            if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
+                deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=118)
+            return out
+
+        if out is not None:
+            accumulate = True
+            out_dtype = out.dtype
+        else:
+            accumulate = False
+            out_dtype = rtn_dtype
+        if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
+            y = paddle.incubate.nn.functional.fp8_gemm_blockwise(
+                a=x_fp8,
+                a_decode_scale=x_scale,
+                b=w_fp8,
+                b_decode_scale=w_scale,
+                out_dtype=out_dtype,
+                out=out,
+                accumulate=accumulate,
+                use_split_accumulator=True,
+                is_a_1d_scaled=is_a_1d_scaled,
+                is_b_1d_scaled=is_b_1d_scaled,
+            )
+        else:
+            y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], out_dtype)
+            if out is not None:
+                out = out + y
+                return out
+
+        return y
+
+    @staticmethod
+    def compute_fp8_linear(
+        input, weight, weight_transpose=False, return_transpose_only=False, return_mode="output_only", *, out=None
+    ):
+        """
+        FP8 linear computation supporting multiple return modes and both quantized and unquantized inputs.
+
+        Args:
+            input: Input tensor, either raw or already quantized as an (input_fp8, input_scale) tuple.
+            weight: Weight tensor.
+            weight_transpose (bool): Whether to transpose the weight.
+            return_transpose_only (bool): Whether to return only the transposed weight.
+            return_mode (str): Return mode, one of:
+                - "output_only": return only the output tensor.
+                - "with_input_quant": return the output plus the quantized input (input_fp8, input_scale).
+                - "with_input_transpose_quant": return the output plus the quantized transposed input (input_t_fp8, input_t_scale).
+        Returns:
+            A combination of tensors depending on return_mode.
+
+        Raises:
+            RuntimeError: If return_mode is not supported.
+        """
+        # check input
+        is_input_quantized = isinstance(input, (tuple, list)) and len(input) == 2
+
+        if is_input_quantized:
+            input_fp8, input_scale = input
+            if return_mode == "with_input_transpose_quant":
+                raise RuntimeError(
+                    "Cannot return transposed quant if input is already quantized. Use raw input instead."
+ ) + else: + # quant input (with optional transposed output) + if return_mode == "with_input_transpose_quant": + input_fp8, input_scale, input_t_fp8, input_t_scale = FP8LinearFunctionBase.padding_and_quant_input( + input + ) + else: + input_fp8, input_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + input, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=False, + return_transpose_only=False, + ) + + # quant weight + weight_fp8, weight_scale = weight_quant(weight, weight_transpose) + + # FP8 GEMM + if out is None: + out = paddle.empty([input_fp8.shape[0], weight_fp8.shape[0]], dtype=weight.dtype) + + deep_gemm.gemm_fp8_fp8_bf16_nt((input_fp8, input_scale.T), (weight_fp8, weight_scale), out, num_sms=118) + + # Return outputs + if return_mode == "output_only": + return out + elif return_mode == "with_input_quant": + return (out, input_fp8, input_scale) + elif return_mode == "with_input_transpose_quant": + return (out, input_t_fp8, input_t_scale) + else: + raise RuntimeError( + f"Unsupported return_mode: {return_mode}. " + "Supported modes: 'output_only', 'with_input_quant', 'with_input_transpose_quant'" + ) + + @staticmethod + def compute_expert_w_grad( + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled=True, + is_b_1d_scaled=True, + weight=None, + rtn_dtype=paddle.bfloat16, + ): + """ + 统一处理 expert_w 的梯度计算(支持 main_grad 和普通 grad) + """ + + if input_t is None or numpy.prod(input_t.shape) == 0: + return + + if hasattr(weight, "main_grad"): + if weight.main_grad is None: + weight.main_grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.kitchen_gemm, + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled, + is_b_1d_scaled, + weight.main_grad, + rtn_dtype, + ) + ) + result = None + + else: + result = FP8LinearFunctionBase.kitchen_gemm( + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled, + is_b_1d_scaled, + weight.main_grad, + rtn_dtype, + ) + else: + if weight.grad is None: + weight.grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) + result = FP8LinearFunctionBase.kitchen_gemm( + input_t, input_t_scale, dout_t, dout_t_scale, is_a_1d_scaled, is_b_1d_scaled, weight.grad, rtn_dtype + ) + + if hasattr(weight, "_apply_backward_hook"): + weight._apply_backward_hook() + return result + + @staticmethod + def common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=None, x_scale=None, apply_backward_hook=False + ): + if o1 is not None and (x_fp8 is not None or x_scale is not None): + raise ValueError("When o1 is provided, both x_fp8 and x_scale must be None.") + + if o1 is None: + if x_fp8 is None or x_scale is None: + raise ValueError("When o1 is None, both x_fp8 and x_scale must be provided.") + + # # ===== [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8) ===== + + # Recompute o1 using deep_gemm(x_fp8, w1_t_fp8) + w1_fp8, w1_scale = weight_quant(w1, True) + o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=118) + + # ===== [recompute] o2 = swiglu(o1) ===== + o2 = swiglu(o1) + + # ===== do2 = deep_gemm(do3_fp8, w2_fp8) + do2, do3_t_fp8, do3_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + do3, w2, return_mode="with_input_transpose_quant" + ) + + # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8) + o2 = FP8LinearFunctionBase.padding(o2, 0) + o2_t_fp8, o2_t_scale = 
paddle.incubate.nn.functional.fp8_quant_blockwise( + o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True + ) + if apply_backward_hook: + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.compute_expert_w_grad, + o2_t_fp8, + o2_t_scale, + do3_t_fp8, + do3_t_scale, + True, + True, + w2, + rtn_dtype=paddle.float32, + ) + ) + else: + + FP8LinearFunctionBase.compute_expert_w_grad( + o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, w2, rtn_dtype=paddle.float32 + ) + else: + dw2 = FP8LinearFunctionBase.kitchen_gemm( + o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, rtn_dtype=paddle.float32 + ) + + # ===== do1 = swiglu_grad(o1, None, do2) ===== + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + + # ===== dx = deep_gemm(do1_fp8, w1_fp8) ===== + dx, do1_t_fp8, do1_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + do1, w1, return_mode="with_input_transpose_quant" + ) + + # ===== dw1 = deep_gemm(x_t_fp8, do1_t_fp8) ===== + if apply_backward_hook: + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.compute_expert_w_grad, + x_t_fp8, + x_t_scale, + do1_t_fp8, + do1_t_scale, + True, + True, + w1, + rtn_dtype=paddle.float32, + ) + ) + + else: + FP8LinearFunctionBase.compute_expert_w_grad( + x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, w1, rtn_dtype=paddle.float32 + ) + else: + dw1 = FP8LinearFunctionBase.kitchen_gemm( + x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, rtn_dtype=paddle.float32 + ) + + if apply_backward_hook: + return dx + else: + assert dw1 is not None and dw2 is not None + return dx, dw1, dw2 + + @staticmethod + def fp8_mlp_fwd(x, w1, w2): + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + # ===== o1 = deep_gemm(x_fp8, w1_t_fp8) ===== + o1, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear( + x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant" + ) + + # ===== o2 = swiglu(o1) ===== + o2 = swiglu(o1) + + # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) ===== + o3 = FP8LinearFunctionBase.compute_fp8_linear(o2, w2, weight_transpose=True, return_transpose_only=True) + + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + return o1, x_fp8, x_scale, o3 + + @staticmethod + def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2): + # ===== compute norm_output ===== + norm_output, _ = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + # ===== compute fp8_mlp_fwd ===== + _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2) + return o3 + + @staticmethod + def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False): + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + x_fp8, x_scale, x_t_fp8, x_t_scale = FP8LinearFunctionBase.padding_and_quant_input(x) + + if apply_backward_hook: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, + x_t_fp8, + x_t_scale, + w1, + w2, + o1=None, + x_fp8=x_fp8, + x_scale=x_scale, + apply_backward_hook=apply_backward_hook, + ) + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + return dx + else: + dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, + x_t_fp8, + x_t_scale, + w1, + w2, + o1=None, + x_fp8=x_fp8, + x_scale=x_scale, + apply_backward_hook=apply_backward_hook, + ) + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + + return dx, dw1, 
dw2 + + @staticmethod + def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2): + # ===== recompute norm_output ===== + norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + + # ===== compute fp8_mlp_fwd ===== + d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2, True) + + if hasattr(norm_w, "_apply_backward_hook"): + norm_w._apply_backward_hook() + + return d_norm_output, norm_output, invar + + +class FP8LinearFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, custom_map, keep_x=False): + weight = custom_map.weight + x_orig_shape = x.shape + + # deep_gemm only support 2D + x = x.reshape([-1, x_orig_shape[-1]]).contiguous() + + if keep_x: + out = FP8LinearFunctionBase.compute_fp8_linear( + x, + weight, + weight_transpose=True, + return_transpose_only=True, + ) + # save for bwd + out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) + ctx.save_for_backward(x, weight) + return out + else: + x_t = x.T + out, x_t_fp8, x_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + x, weight, weight_transpose=True, return_transpose_only=True, return_mode="with_input_transpose_quant" + ) + out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) + ctx.save_for_backward((x_t_fp8, x_t_scale), weight) + ctx.x_t_shape = x_t.shape + return out + + @staticmethod + def backward(ctx, dout): + x, weight = ctx.saved_tensor() + dout_2d = dout.reshape([-1, dout.shape[-1]]) + + keep_x = not isinstance(x, tuple) + + if keep_x: + # padding x and quant + dx_orig_shape = x.shape + x = FP8LinearFunctionBase.padding(x, 0) + x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True + ) + + # ===== dx = deep_gemm(dout_fp8, w_fp8) + dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant" + ) + dx = dx.reshape(dx_orig_shape) + + else: + x_t_fp8, x_t_scale = x + + # ===== dx = deep_gemm(dout_fp8, w_fp8) + dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant" + ) + dx_orig_shape = dout.shape[:-1] + dx_orig_shape.append(ctx.x_t_shape[0]) + dx = dx.reshape(dx_orig_shape) + + # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8) + FP8LinearFunctionBase.compute_expert_w_grad( + x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight, paddle.float32 + ) + return dx + + +class FP8Linear(paddle.nn.Layer): + def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.weight = self.create_parameter( + shape=[in_features, out_features], + dtype="bfloat16", + is_bias=False, + ) + + def forward(self, x): + return FP8LinearFunction.apply(x, self, keep_x=False) + + +def cache_fp8_weight(weight): + if hasattr(weight, "fp8_weight"): + return + w_fp8, w_scale, w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=True, + return_transpose_only=False, + ) + + setattr(weight, "fp8_weight_transpose", w_t_fp8) + setattr(weight, "fp8_scale_transpose", w_t_scale) + setattr(weight, "fp8_weight", w_fp8) + setattr(weight, "fp8_scale", w_scale) + + +class FP8KeepXLinear(paddle.nn.Layer): + def __init__(self, in_features: int, out_features: int, bias_attr: bool 
= False) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.weight = self.create_parameter( + shape=[in_features, out_features], + dtype="bfloat16", + is_bias=False, + ) + set_parameter_color([self.weight], "attn_out_project") + + def fp8_quant_weight(self): + cache_fp8_weight(self.weight) + + def forward(self, x): + return FP8LinearFunction.apply(x, self, keep_x=True) + + +class FusedNormFP8MLPFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, norm_w, w1, w2, norm_eps): + # ===== compute norm_output ===== + norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + x_orig_shape = norm_output.shape + norm_output = norm_output.reshape([-1, x_orig_shape[-1]]) + + # ===== call func fp8_mlp_fwd ===== + _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2) + + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + # ===== save for backward ===== + ctx.save_for_backward( + norm_output, + invar, + x, + norm_w, + w1, + w2, + norm_eps, + paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), + ) + return o3 + + @staticmethod + def backward(ctx, do3): + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + # ===== recive saved tensors ===== + norm_output, invar, x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor() + + x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + norm_output, output_scale_transpose=True, quant_method="1x128", input_transpose=True + ) + + # ===== call func common_fp8_mlp_bwd ===== + d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale + ) + + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + d_norm_output = d_norm_output.reshape([x_orig_shape[0], -1, d_norm_output.shape[-1]]) + + # ===== compute norm grad ===== + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps) + + return dx, d_rms_norm_weight, dw1, dw2 + + +class FP8MlpFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, w1, w2, recompute_fwd_gate_up): + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + # ===== call func fp8_mlp_fwd ===== + o1, x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2) + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + # ===== save for backward ===== + o1 = None if recompute_fwd_gate_up else o1 + ctx.save_for_backward( + o1, + x_fp8, + x_scale, + w1, + w2, + paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), + ) + return o3 + + @staticmethod + def backward(ctx, do3): + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + # ===== recive saved tensors ===== + o1, x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor() + + # ===== compute x_t_fp8, x_t_scale for dw1 ===== + x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous()) + x_dequant_fp16 = FP8LinearFunctionBase.padding(x_dequant_fp16, 0) + + x_t_fp8, x_t_scale = 
paddle.incubate.nn.functional.fp8_quant_blockwise( + x_dequant_fp16, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + + # ===== call func common_fp8_mlp_bwd ===== + if o1 is None: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale, apply_backward_hook=True + ) + else: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True + ) + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + + return dx, None, None + + +class FP8Mlp(paddle.nn.Layer): + def __init__( + self, + config, + hidden_size=None, + intermediate_size=None, + is_moe=False, + using_post_norm_recompute=False, + norm_weight=None, + norm_eps=None, + recompute_fwd_gate_up=False, + ): + super().__init__() + self.config = config + self.using_post_norm_recompute = using_post_norm_recompute + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + self.norm_weight = norm_weight + self.norm_eps = norm_eps + + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + + self.recompute_fwd_gate_up = recompute_fwd_gate_up + + self.w1 = self.create_parameter( + shape=[self.hidden_size, self.intermediate_size * 2], + dtype="bfloat16", + is_bias=False, + ) + self.w2 = self.create_parameter( + shape=[self.intermediate_size, self.hidden_size], + dtype="bfloat16", + is_bias=False, + ) + + def fp8_quant_weight(self): + cache_fp8_weight(self.w1) + cache_fp8_weight(self.w2) + + def forward(self, x): + if self.using_post_norm_recompute: + return FusedNormFP8MLPFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps) + else: + return FP8MlpFunction.apply(x, self.w1, self.w2, self.recompute_fwd_gate_up) + + +def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out): + start_idx = 0 + for i, token_num in enumerate(tokens_per_expert): + if token_num == 0: + continue + end_idx = start_idx + token_num + + x_scale_tma_align = x_scale[start_idx:end_idx].T.contiguous().T + + deep_gemm.gemm_fp8_fp8_bf16_nt( + (x_fp8[start_idx:end_idx], x_scale_tma_align), + (w_fp8[i], w_scale[i]), + gemm_out[start_idx:end_idx], + num_sms=118, + ) + + start_idx = end_idx + + return gemm_out + + +class FP8GroupGemmMlpFunctionNode: + def __init__( + self, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=False, + name="experts_group_gemm_contiguous_node", + ): + self.experts = custom_map.experts + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.is_split_group_gemm = is_split_group_gemm + self.m_indices = None + self.input = None + self.input_fp8 = None + self.input_scale = None + self.o1 = None + self.all_unzipped_grad = None + self.fwd_subbatch = None + self.bwd_subbatch = None + + def reset_statue(self): + self.m_indices = None + self.fwd_subbatch = None + self.bwd_subbatch = None + self.clear_activation_tensors() + + def clear_activation_tensors(self): + self.input = None + self.input_fp8 = None + self.input_scale = None + self.o1 = None + self.all_unzipped_grad = None + + def gen_m_indices(self, tokens_per_expert): + tokens = [] + for i in range(len(tokens_per_expert)): + tokens.append(paddle.full([tokens_per_expert[i]], i, dtype="int32")) + out = 
paddle.concat(tokens, axis=0) + return out + + def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert, m_indices=None): + """ + o1 = x * w1 + [m_sum, n] = [m_sum, k] * [num_groups, k, n] (m_sum = sum(tokens_per_expert)) + """ + if not self.is_split_group_gemm and self.m_indices is None: + self.m_indices = self.gen_m_indices(tokens_per_expert) + # concat w1, shape is [num_groups, n, k] + w1_t_quant, w1_t_scale = fused_stack_quant(expert_w1, transpose=True) + w1_t_quant = w1_t_quant.reshape([num_expert, -1, w1_t_quant.shape[-1]]) + w1_t_scale = w1_t_scale.reshape([num_expert, -1, w1_t_scale.shape[-1]]) + + if x is None: + x_fp8, x_scale = self.input_fp8, self.input_scale + assert x_fp8 is not None and x_scale is not None + else: + if isinstance(x, tuple): + (x_fp8, x_scale) = x + x_scale = paddle.transpose(paddle.transpose(x_scale, [1, 0]).contiguous(), [1, 0]) + else: + # quant x_bf16 + x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + x_scale = x_scale.T + + # compute gemm + o1 = paddle.empty([x_fp8.shape[0], w1_t_quant.shape[1]], dtype=expert_w1[0].dtype) + if numpy.prod(x_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(x_fp8, x_scale, w1_t_quant, w1_t_scale, tokens_per_expert, o1) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (x_fp8, x_scale), + (w1_t_quant, w1_t_scale), + o1, + m_indices=self.m_indices if m_indices is None else m_indices, + num_sms=118, + ) + + if m_indices is None: + self.input_fp8 = x_fp8 + self.input_scale = x_scale + return o1 + + def fwd_swiglu(self, o1): + o2 = swiglu(o1) + return o2 + + def fwd_down( + self, o1, unzipped_probs, expert_w2, num_expert, tokens_per_expert, m_indices=None, o3=None, clear_o1=False + ): + """ + o3 = o2 * w2 + [m_sum, k] = [m_sum, n] * [num_groups, n, k] + """ + # concat and transpose w2 + w2_quant, w2_scale = fused_stack_quant(expert_w2, transpose=True) + w2_quant = w2_quant.reshape([num_expert, -1, w2_quant.shape[-1]]) + w2_scale = w2_scale.reshape([num_expert, -1, w2_scale.shape[-1]]) + + # quant o2 + with paddle.amp.auto_cast(False): + unzipped_probs = unzipped_probs.squeeze(-1) + o2_fp8, o2_scale = paddle.incubate.nn.functional.fused_weighted_swiglu_act_quant( + o1, unzipped_probs, using_pow2_scaling=True + ) + o2_scale = paddle.transpose(paddle.transpose(o2_scale, [1, 0]).contiguous(), [1, 0]) + + if clear_o1: + o1._clear_to_zero_allocation() + + # compute gemm + o3_shape = [o2_fp8.shape[0], w2_quant.shape[1]] + if o3 is not None: + assert o3.shape == o3_shape, "{} vs {}".format(o3.shape, o3_shape) + else: + o3 = paddle.empty(o3_shape, dtype=o1.dtype) + if numpy.prod(o2_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_scale, tokens_per_expert, o3) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (o2_fp8, o2_scale), + (w2_quant, w2_scale), + o3, + m_indices=m_indices if self.fwd_subbatch else self.m_indices, + num_sms=118, + ) + + return o3 + + def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, tokens_per_expert, m_indices=None, unzipped_probs=None): + """ + do2 = do3 * w2_t + [m_sum, n] = [m_sum, k] * [num_groups, k, n] + """ + # recompute concated_w2_2d + bw_w2_quant, bw_w2_scale = fused_stack_quant(expert_w2, transpose=False) + bw_w2_quant = bw_w2_quant.reshape([len(expert_w2), -1, bw_w2_quant.shape[-1]]) + bw_w2_scale = bw_w2_scale.reshape([len(expert_w2), -1, bw_w2_scale.shape[-1]]) + + # compute gemm + if 
isinstance(unzipped_grad, tuple): + (unzipped_grad_fp8, unzipped_grad_scale) = unzipped_grad + unzipped_grad_scale = unzipped_grad_scale.T.contiguous().T + else: + unzipped_grad_fp8, unzipped_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + unzipped_grad, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + unzipped_grad_scale = unzipped_grad_scale.T + + do2_s = paddle.empty([unzipped_grad_fp8.shape[0], bw_w2_quant.shape[1]], dtype="bfloat16") + if numpy.prod(unzipped_grad_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm( + unzipped_grad_fp8, unzipped_grad_scale, bw_w2_quant, bw_w2_scale, tokens_per_expert, do2_s + ) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (unzipped_grad_fp8, unzipped_grad_scale), + (bw_w2_quant, bw_w2_scale), + do2_s, + m_indices=m_indices if self.bwd_subbatch else self.m_indices, + num_sms=118, + ) + + with paddle.amp.auto_cast(False): + do1, probs_grad, o2_s = paddle.incubate.nn.functional.fused_swiglu_weighted_bwd(o1, do2_s, unzipped_probs) + + return do1, o2_s, probs_grad + + def bwd_swiglu(self, o1, do2): + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + return do1 + + def bwd_gate_up_input(self, do1, expert_w1, tokens_per_expert, m_indices=None, dx=None): + """ + dx = do1 * w1_t + [m_sum, k] = [m_sum, n] * [num_groups, n, k] + """ + # recompute concated_w1_t + bw_w1_quant, bw_w1_scale = fused_stack_quant(expert_w1, transpose=False) + bw_w1_quant = bw_w1_quant.reshape([len(expert_w1), -1, bw_w1_quant.shape[-1]]) + bw_w1_scale = bw_w1_scale.reshape([len(expert_w1), -1, bw_w1_scale.shape[-1]]) + + # quant do1 + do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + do1_scale = do1_scale.T + # compute gemm + dx_shape = [do1_fp8.shape[0], bw_w1_quant.shape[1]] + if dx is None or dx.dtype != do1.dtype: + dx = paddle.empty(shape=dx_shape, dtype=do1.dtype) + else: + assert dx.shape == dx_shape, f"{dx.shape} vs {dx_shape}" + if numpy.prod(do1_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(do1_fp8, do1_scale, bw_w1_quant, bw_w1_scale, tokens_per_expert, dx) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (do1_fp8, do1_scale), + (bw_w1_quant, bw_w1_scale), + dx, + m_indices=m_indices if self.bwd_subbatch else self.m_indices, + num_sms=118, + ) + + return dx + + def fused_transpose_split_quant(self, x, scale, tokens_per_expert, pow_2_scales): + out, scale = paddle.incubate.nn.functional.fused_transpose_split_quant( + x, scale, tokens_per_expert, pow_2_scales + ) + return out, scale + + def bwd_down_weight(self, do3, o2, expert_w2, tokens_per_expert): + """ + dw2 = do2_t * do3 + [n, k] = [n, m_sum] * [m_sum, k] (m_sum = sum(tokens_per_expert)) + """ + if isinstance(o2, tuple): + o2_t_fp8, o2_t_scale = o2 + else: + o2_t_fp8, o2_t_scale = self.fused_transpose_split_quant(o2, None, tokens_per_expert, True) + + if isinstance(do3, tuple): + do3_t_fp8, do3_t_scale = do3 + else: + do3_t_fp8, do3_t_scale = self.fused_transpose_split_quant(do3, None, tokens_per_expert, True) + + def cal_weight_fn(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2): + with paddle.no_grad(): + for i in range(len(expert_w2)): + FP8LinearFunctionBase.compute_expert_w_grad( + o2_t_fp8[i], + o2_t_scale[i], + do3_t_fp8[i], + do3_t_scale[i], + True, + True, + expert_w2[i], + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put(partial(cal_weight_fn, o2_t_fp8, 
o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2)) + else: + cal_weight_fn(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2) + + def bwd_gate_up_weight( + self, + do1, + input_x, + expert_w1, + tokens_per_expert, + input_fp8_slice=None, + input_scale_slice=None, + clear_input=False, + ): + """ + dw1 = dx_t * do1 + [k, n] = [k, m_sum] * [m_sum, n] (m_sum = sum(tokens_per_expert)) + """ + if input_x is None: + inp = (input_fp8_slice, input_scale_slice) if self.bwd_subbatch else (self.input_fp8, self.input_scale) + input_x_t_fp8, input_x_t_scale = self.fused_transpose_split_quant(inp[0], inp[1], tokens_per_expert, True) + + else: + input_x_t_fp8, input_x_t_scale = self.fused_transpose_split_quant(input_x, None, tokens_per_expert, True) + + if clear_input: + if self.input_fp8 is not None: + self.input_fp8._clear_to_zero_allocation() + self.input_fp8 = None + if self.input_scale is not None: + self.input_scale._clear_to_zero_allocation() + self.input_scale = None + if self.input is not None: + self.input._clear_to_zero_allocation() + self.input = None + + do1_t_fp8, do1_t_scale = self.fused_transpose_split_quant(do1, None, tokens_per_expert, True) + + def cal_weight_fn(input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1): + with paddle.no_grad(): + for i in range(len(expert_w1)): + FP8LinearFunctionBase.compute_expert_w_grad( + input_x_t_fp8[i], + input_x_t_scale[i], + do1_t_fp8[i], + do1_t_scale[i], + True, + True, + expert_w1[i], + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(cal_weight_fn, input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1) + ) + else: + cal_weight_fn(input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1) + + @paddle.no_grad() + def forward(self, hs_out, unzipped_probs, tokens_per_expert, m_indices=None): + # check subbatch + if self.fwd_subbatch: + assert m_indices is not None + # deal 0 size + dtype = paddle.bfloat16 + if hs_out is None: + assert self.input_fp8 is not None + assert self.input_scale is not None + shape = self.input_fp8.shape + else: + if isinstance(hs_out, tuple): + shape = hs_out[0].shape + else: + shape = hs_out.shape + + if shape[0] == 0: + o3 = paddle.zeros(shape, dtype=dtype) + return o3 + + # get w1/w2 + expert_w1 = [x.w1 for x in self.experts if x is not None] + expert_w2 = [x.w2 for x in self.experts if x is not None] + + num_expert = len(expert_w1) + + # o1 + o1 = self.fwd_gate_up(hs_out, expert_w1, num_expert, tokens_per_expert, m_indices) + if not self.recompute_fwd_gate_up: + self.o1 = o1 + clear_o1 = False + else: + clear_o1 = True + + # o3 + o3 = self.fwd_down( + o1, unzipped_probs, expert_w2, num_expert, tokens_per_expert, clear_o1=clear_o1, m_indices=m_indices + ) + + # save for bwd + return o3 + + @paddle.no_grad() + def backward( + self, + out_grad, + unzipped_probs, + tokens_per_expert, + input_fp8_slice=None, + input_scale_slice=None, + m_indices=None, + reset_status=False, + ): + # check subbatch + if self.bwd_subbatch: + assert ( + m_indices is not None + and input_fp8_slice is not None + and input_scale_slice is not None + and tokens_per_expert is not None + ) + # deal 0 size + dtype = paddle.bfloat16 + shape = out_grad[0].shape if isinstance(out_grad, tuple) else out_grad.shape + if shape[0] == 0: + return paddle.zeros_like(extract_first_if_tuple(out_grad), dtype=dtype), paddle.zeros_like(unzipped_probs) + + # recompute expert_w2 and expert_w1 + expert_w1 = [x.w1 for x in self.experts if x is not None] + expert_w2 = [x.w2 for x in self.experts if x 
is not None] + + if self.recompute_fwd_gate_up: + inp = None if not self.bwd_subbatch else (input_fp8_slice, input_scale_slice) + o1 = self.fwd_gate_up(inp, expert_w1, len(expert_w1), tokens_per_expert, m_indices=m_indices) + else: + o1 = self.o1 + + # do2 + do1, o2_s, probs_grad = self.bwd_dowm_input( + expert_w2, out_grad, o1, tokens_per_expert, unzipped_probs=unzipped_probs, m_indices=m_indices + ) + del o1 + if self.o1 is not None: + self.o1._clear_to_zero_allocation() + self.o1 = None + + # dw1 + self.bwd_gate_up_weight( + do1, + None, + expert_w1, + tokens_per_expert, + input_fp8_slice=input_fp8_slice, + input_scale_slice=input_scale_slice, + clear_input=reset_status, + ) + + if reset_status: + if self.input_fp8 is not None: + self.input_fp8._clear_to_zero_allocation() + self.input_fp8 = None + if self.input_scale is not None: + self.input_scale._clear_to_zero_allocation() + self.input_scale = None + if self.input is not None: + self.input._clear_to_zero_allocation() + self.input = None + + # dx + dx = self.bwd_gate_up_input( + do1, + expert_w1, + tokens_per_expert, + dx=out_grad[0] if isinstance(out_grad, tuple) else out_grad, + m_indices=m_indices, + ) + del do1 + + # dw2 + if isinstance(out_grad, tuple): + do3_fp8, do3_scale = self.fused_transpose_split_quant(out_grad[0], out_grad[1], tokens_per_expert, True) + out_grad[0]._clear_to_zero_allocation() + out_grad[1]._clear_to_zero_allocation() + self.bwd_down_weight((do3_fp8, do3_scale), o2_s, expert_w2, tokens_per_expert) + else: + self.bwd_down_weight(out_grad, o2_s, expert_w2, tokens_per_expert) + + if reset_status: + self.reset_statue() + return dx, probs_grad diff --git a/paddleformers/transformers/fused_a2a.py b/paddleformers/transformers/fused_a2a.py index 7b5fa09c9e0..400f97cd0a4 100644 --- a/paddleformers/transformers/fused_a2a.py +++ b/paddleformers/transformers/fused_a2a.py @@ -72,78 +72,144 @@ def get_buffer(group: Group, hidden_bytes: int): return _buffer +def fused_dispatch_forward_func( + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, +): + """Forward pass of fused dispatch.""" + # Calculate layout before actual dispatch + if isinstance(x, tuple): + buffer = get_buffer(group, get_hidden_bytes(x[0])) + else: + buffer = get_buffer(group, get_hidden_bytes(x)) + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + _previous_event, + ) = buffer.get_dispatch_layout( + token_indices, + num_experts, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + # Do MoE dispatch + # NOTES: the CPU will wait for GPU's signal to arrive, + # so this is not compatible with CUDA graph + (recv_x, recv_token_indices, recv_token_probs, num_recv_tokens_per_expert_list, handle, event,) = buffer.dispatch( + x, + topk_idx=token_indices, + topk_weights=token_probs.cast(paddle.float32), + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + states = dict() + states["dispatched_indices"] = recv_token_indices + states["tokens_per_expert"] = num_recv_tokens_per_expert_list + states["handle"] = handle + + return recv_x, recv_token_probs, states, event + + +def fused_dispatch_backward_func( + grad_output, + 
grad_token_probs, + group, + handle, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, +): + """Backward pass of fused dispatch.""" + buffer = get_buffer(group, get_hidden_bytes(grad_output)) + + grad_x, grad_token_probs, event = buffer.combine( + grad_output.contiguous(), + handle, + topk_weights=grad_token_probs.cast(paddle.float32), + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return grad_x, None, grad_token_probs + + +def fused_combine_forward_func( + x, group, states, previous_event=None, async_finish=False, allocate_on_comm_stream=False +): + """Forward pass of fused combine.""" + handle = states["handle"] + buffer = get_buffer(group, get_hidden_bytes(x)) + combined_x, _, event = buffer.combine( + x, + handle=handle, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return combined_x + + +def fused_combine_backward_func( + grad_output, group, handle, previous_event=None, async_finish=False, allocate_on_comm_stream=False +): + """Backward pass of fused combine.""" + if isinstance(grad_output, tuple): + buffer = get_buffer(group, get_hidden_bytes(grad_output[0])) + grad_x, _, _, _, _, event = buffer.dispatch( + (grad_output[0].contiguous(), grad_output[1].contiguous()), + handle=handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + else: + buffer = get_buffer(group, get_hidden_bytes(grad_output)) + grad_x, _, _, _, _, event = buffer.dispatch( + grad_output.contiguous(), + handle=handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return grad_x + + class FusedDispatch(PyLayer): """Fused dispatch operation for MoE routing combining computation and communication.""" @staticmethod def forward(ctx, x, token_indices, token_probs, num_experts, group, previous_event=None): """Forward pass of fused dispatch.""" - # Calculate layout before actual dispatch - buffer = get_buffer(group, get_hidden_bytes(x)) - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - token_indices, - num_experts, - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, - ) - - # Do MoE dispatch - # NOTES: the CPU will wait for GPU's signal to arrive, - # so this is not compatible with CUDA graph - ( - recv_x, - recv_token_indices, - recv_token_probs, - num_recv_tokens_per_expert_list, - handle, - event, - ) = buffer.dispatch( - x, - topk_idx=token_indices, - topk_weights=token_probs.cast(paddle.float32), - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, + recv_x, recv_token_probs, states, event = fused_dispatch_forward_func( + x, token_indices, token_probs, num_experts, group, previous_event ) ctx.group = group - ctx.handle = handle + ctx.handle = states["handle"] ctx.event = event - tokens_per_expert = paddle.to_tensor(num_recv_tokens_per_expert_list) - - states = dict() - states["dispatched_indices"] = recv_token_indices - states["tokens_per_expert"] = tokens_per_expert - states["handle"] = handle return recv_x, recv_token_probs, states @staticmethod def backward(ctx, grad_output, 
grad_token_probs): """Backward pass of fused dispatch.""" - buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) - handle = ctx.handle - - grad_x, grad_token_probs, event = buffer.combine( - grad_output.contiguous(), - handle, - topk_weights=grad_token_probs.cast(paddle.float32), - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, - ) - return grad_x, None, grad_token_probs + return fused_dispatch_backward_func(grad_output, grad_token_probs, ctx.group, ctx.handle) class FusedCombine(PyLayer): @@ -152,12 +218,9 @@ class FusedCombine(PyLayer): @staticmethod def forward(ctx, x, group, states, previous_event=None): """Forward pass of fused combine.""" - handle = states["handle"] - buffer = get_buffer(group, get_hidden_bytes(x)) - combined_x, _, event = buffer.combine( - x, handle=handle, async_finish=False, previous_event=None, allocate_on_comm_stream=False - ) - ctx.handle = handle + combined_x = fused_combine_forward_func(x, group, states, previous_event) + + ctx.handle = states["handle"] ctx.group = group ctx.previous_event = previous_event @@ -166,15 +229,7 @@ def forward(ctx, x, group, states, previous_event=None): @staticmethod def backward(ctx, grad_output): """Backward pass of fused combine.""" - buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) - grad_x, _, _, _, _, event = buffer.dispatch( - grad_output.contiguous(), - handle=ctx.handle, - previous_event=ctx.previous_event, - async_finish=False, - allocate_on_comm_stream=False, - ) - return grad_x + return fused_combine_backward_func(grad_output, ctx.group, ctx.handle, ctx.previous_event) if HAVE_DEEP_EP: @@ -214,3 +269,96 @@ def fused_combine(x, group, handle, previous_event=None): else: fused_dispatch = None fused_combine = None + + +class DispatchNode: + def __init__(self, name="dispatch"): + self.name = name + + def reset_statue(self): + self.handle = None + + def forward( + self, + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + """Forward pass of fused dispatch.""" + recv_x, recv_token_probs, states, event = fused_dispatch_forward_func( + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + self.group = group + self.handle = states["handle"] + self.event = event + + return recv_x, recv_token_probs, states + + def backward( + self, grad_output, grad_token_probs, previous_event=None, async_finish=False, allocate_on_comm_stream=False + ): + """Backward pass of fused dispatch.""" + out = fused_dispatch_backward_func( + grad_output, + grad_token_probs, + self.group, + self.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.reset_statue() + return out + + +class CombineNode: + def __init__(self, name="combine"): + self.name = name + + def reset_statue(self): + self.handle = None + + def forward(self, x, group, handle, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + """Forward pass of fused combine.""" + states = dict() + states["handle"] = handle + combined_x = fused_combine_forward_func( + x, + group, + states, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + self.handle = handle + self.group = group + self.previous_event = previous_event + + return combined_x + + def backward(self, 
grad_output, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + """Backward pass of fused combine.""" + out = fused_combine_backward_func( + grad_output, + self.group, + self.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.reset_statue() + return out \ No newline at end of file diff --git a/paddleformers/transformers/moe_layer.py b/paddleformers/transformers/moe_layer.py index 340fba1f524..f14e45d13a6 100644 --- a/paddleformers/transformers/moe_layer.py +++ b/paddleformers/transformers/moe_layer.py @@ -16,6 +16,7 @@ # limitations under the License. from __future__ import annotations +import os from typing import Any, List, Tuple import numpy as np @@ -24,8 +25,48 @@ from paddle import Tensor, nn from paddle.distributed.communication.group import Group +from ..utils.log import logger +from .fp8_utils import FP8GroupGemmMlpFunctionNode, extract_first_if_tuple +from .fused_a2a import CombineNode, DispatchNode, get_buffer, get_hidden_bytes from .moe_gate import PretrainedMoEGate -from .token_dispatcher import MoEFlexTokenDispatcher +from .moe_utils import ( + UnZipNode, + ZipNode, + merge_subbatch_cast, + offload, + reload, + tokens_zip_unique_add_with_subbatch, +) +from .token_dispatcher import MoEFlexTokenDispatcher, PreDispatchNode + +try: + import paddle.distributed.communication.deep_ep as deep_ep +except ImportError: + deep_ep = None + +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" + +DSV3_USE_FP8_GROUP_GEMM = os.getenv("DSV3_USE_FP8_GROUP_GEMM", "False").lower() == "true" + +DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true" + + +import TokenDispatcherUtils as TDU + + +def record_stream_for_multi_input(x): + if isinstance(x, (tuple, list)): + for i in range(len(x)): + x[i]._record_stream() + else: + x._record_stream() + + +def stop_gradient_for_multi_input(x): + if isinstance(x, (tuple, list)): + x[0].stop_gradient = False + else: + x.stop_gradient = False def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): @@ -162,6 +203,7 @@ def __init__( capacity: int = 1.0, moe_group: str = "data", all_to_all_dropout=0.0, + using_post_norm_recompute=False, ): super().__init__() @@ -176,12 +218,11 @@ def __init__( except AttributeError: is_fleet_init = False - if ( - is_fleet_init - and dist.fleet.get_hybrid_communicate_group().get_data_parallel_world_size() > 1 - and moe_group == "data" - ): - self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + if is_fleet_init and dist.get_world_size() > 1: + if moe_group == "data": + self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + elif moe_group == "expert": + self.moe_group = dist.fleet.get_hybrid_communicate_group().expert_parallel_group self.moe_rank = dist.get_rank(self.moe_group) self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank self.expert_parallel_degree = dist.get_world_size(self.moe_group) @@ -210,8 +251,32 @@ def __init__( self.gate = gate self.gate.group = self.moe_group + # for flex token moe layer + self.router = gate + self.ep_size = dist.get_world_size(self.moe_group) + self.moe_router_topk = gate.top_k + self.num_local_experts = moe_num_experts // self.ep_size + self.token_dispatcher = MoEFlexTokenDispatcher( + self.num_local_experts, self.moe_router_topk, self.moe_num_experts, self.moe_group + ) + self.token_drop_steps = config.token_drop_steps + self.using_flex_token = 
False + + self.using_post_norm_recompute = using_post_norm_recompute self._post_init() + def update_flex_token(self): + from paddleformers.transformers.deepseek_v2 import get_global_step + + if (not self.config.using_flex_token) or (get_global_step() < self.token_drop_steps): + self.using_flex_token = False + self.router.using_flex_token = False + else: + if not self.using_flex_token: + logger.info("Changing to flex token moe mode") + self.using_flex_token = True + self.router.using_flex_token = True + def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): assert ( moe_num_experts >= expert_parallel_degree @@ -234,8 +299,35 @@ def _post_init(self): # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") def forward( + self, + hidden_states: paddle.Tensor, + probs=None, + routing_map=None, + capacity=None, + topk_weight=None, + topk_ids=None, + token_priority=None, + l_aux=None, + l_zloss=None, + ): + self.update_flex_token() + + if self.using_flex_token: + return self.forward_flex_token(hidden_states, probs, routing_map, l_aux, l_zloss) + else: + return self.forward_drop_token( + hidden_states, capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss + ) + + def forward_drop_token( self, hidden_state: paddle.Tensor, + capacity=None, + topk_weight=None, + topk_ids=None, + token_priority=None, + l_aux=None, + l_zloss=None, ): """MoE Layer forward function 1. Gate Forward. @@ -257,7 +349,17 @@ def forward( # topk_ids : sk # token_priority : se # self.exp_counts : - capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss = self.gate(hidden_state) + if self.using_post_norm_recompute: + assert ( + capacity is not None + and topk_weight is not None + and topk_ids is not None + and token_priority is not None + and l_aux is not None + and l_zloss is not None + ) + else: + capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss = self.gate(hidden_state) """MoE expert dispatch from: https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py""" cnts = paddle.zeros([topk_ids.shape[0], len(self.experts)], dtype=topk_ids.dtype) @@ -336,6 +438,801 @@ def forward( return final_out, l_aux, l_zloss + def forward_flex_token(self, hidden_states: paddle.Tensor, probs=None, routing_map=None, l_aux=None, l_zloss=None): + _, _, d_model = hidden_states.shape + # reshaped_input = hidden_states.reshape([-1, d_model]) + if self.using_post_norm_recompute: + assert probs is not None and routing_map is not None and l_aux is not None and l_zloss is not None + else: + probs, routing_map, l_aux, l_zloss = self.router(hidden_states) + if DSV3_USE_FP8_GEMM: + if DSV3_USE_FP8_DISPATCH: + output = FusionMoe.apply( + hidden_states, + probs, + routing_map, + self, + recompute_fwd_gate_up=self.config.recompute_fwd_gate_up, + is_split_group_gemm=self.config.is_split_group_gemm, + ) + else: + hidden_states, token_indices, token_probs = self.token_dispatcher.pre_dispatch( + hidden_states, probs, routing_map + ) + output = FusionMoe.apply( + hidden_states, + token_indices, + token_probs, + self, + recompute_fwd_gate_up=self.config.recompute_fwd_gate_up, + is_split_group_gemm=self.config.is_split_group_gemm, + ) + else: + ( + dispatched_input, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ) = self.token_dispatcher.token_permutation_fast(hidden_states, probs, routing_map) + + expert_output = self.expert_forward(dispatched_input) + output, _ = self.token_dispatcher.token_unpermutation_fast( + expert_output, token_permuted_indices, 
prob_permuted_indices, dispatched_probs, None + ) + return output, l_aux, l_zloss + + def get_tokens_per_expert(self): + return self.token_dispatcher._comm_manager.tokens_per_expert_list + + def set_tokens_per_expert(self, tokens_per_expert_list): + self.token_dispatcher._comm_manager.tokens_per_expert_list = tokens_per_expert_list + + def expert_forward(self, dispatched_input): + outputs = [] + chunks = paddle.split(dispatched_input, num_or_sections=self.get_tokens_per_expert(), axis=0) + for i, chunk in enumerate(chunks): + chunk = chunk.contiguous() + # assert chunk.shape[0] != 0, "Cannot dispatch empty input" + expert = self.experts[i + self.moe_rank * self.moe_num_experts_per_device] + outputs += [expert(chunk)] + + return paddle.concat(outputs, axis=0) + + def pre_dispatch_compute(self, hidden_states): + _, _, d_model = hidden_states.shape + probs, routing_map, l_aux, l_zloss = self.router(hidden_states) + hidden_states, token_indices, token_probs = self.token_dispatcher.pre_dispatch( + hidden_states, probs, routing_map + ) + return l_aux, l_zloss, hidden_states, token_indices, token_probs + + def post_dispatch_compute(self, hidden_states, dispatched_indices, dispatched_probs): + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.token_dispatcher.post_dispatch( + hidden_states, dispatched_indices + ) + return (global_input_tokens, token_permuted_indices, prob_permuted_indices) + + def pre_combine_compute(self, hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs): + hidden_states = self.token_dispatcher.pre_combine( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + return hidden_states + + def post_combine_compute(self, hidden_states): + hidden_states = self.token_dispatcher.post_combine(hidden_states) + return hidden_states + + +class Fp8DispatchQuantNode: + def __init__(self, token_dispatcher, name="fp8_dispatch_quant_node"): + self.token_dispatcher = token_dispatcher + self.pre_dispatch_node = PreDispatchNode(token_dispatcher) + self.name = name + + @paddle.no_grad() + def forward(self, hidden_states, probs, routing_map): + # reshape + self.token_dispatcher.hidden_shape = hidden_states.shape + hs_2d = hidden_states.view([-1, self.token_dispatcher.hidden_shape[-1]]) + + if DSV3_USE_FP8_DISPATCH: + # quant + hs_fp8, hs_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hs_2d, output_scale_transpose=False, quant_method="1x128", input_transpose=False + ) + + # pre_dispatch + token_indices, token_probs = self.pre_dispatch_node.forward(routing_map, probs) + + self.hidden_states_shape = hidden_states.shape + hs_fp8.stop_gradient = False + token_probs.stop_gradient = False + return (hs_fp8, hs_scale), token_indices, token_probs + else: + # pre_dispatch + token_indices, token_probs = self.pre_dispatch_node.forward(routing_map, probs) + + self.hidden_states_shape = hidden_states.shape + hs_2d.stop_gradient = False + token_probs.stop_gradient = False + return hs_2d, token_indices, token_probs + + @paddle.no_grad() + def backward(self, hs_grad, token_probs_grad): + # predispatch grad + probs_grad = self.pre_dispatch_node.backward(token_probs_grad) + token_probs_grad._record_stream() + + # reshape_grad + hs_grad = hs_grad.view(self.hidden_states_shape) + hs_grad._record_stream() + + return hs_grad, probs_grad, None + + +class Fp8DispatchNode: + def __init__(self, token_dispatcher, name="fp8_dispatch_node"): + self.token_dispatcher = token_dispatcher + self.dispatch_act_node = 
DispatchNode(token_dispatcher) + self.name = name + + @paddle.no_grad() + def forward( + self, + hs_2d, + token_indices, + token_probs, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + # dispatch + hs_2d_dispatched, dispatched_probs, states = self.dispatch_act_node.forward( + hs_2d, + token_indices, + token_probs, + self.token_dispatcher._comm_manager.num_experts, + self.token_dispatcher._comm_manager.group, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.token_dispatcher._comm_manager.handle = states["handle"] + self.token_dispatcher._comm_manager.tokens_per_expert = states["tokens_per_expert"] + dispatched_indices = states["dispatched_indices"] + + stop_gradient_for_multi_input(hs_2d_dispatched) + dispatched_probs.stop_gradient = False + return hs_2d_dispatched, dispatched_indices, dispatched_probs + + @paddle.no_grad() + def backward( + self, + hs_dispatched_grad, + dispatched_probs_grad, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + # dispatch grad + hs_grad, _, token_probs_grad = self.dispatch_act_node.backward( + hs_dispatched_grad, + dispatched_probs_grad, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return hs_grad, token_probs_grad + + +class Fp8CombineNode: + def __init__(self, token_dispatcher, name="fp8_combine_node"): + self.token_dispatcher = token_dispatcher + self.combine_node = CombineNode(token_dispatcher) + self.name = name + + @paddle.no_grad() + def forward(self, hidden_states_out, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + # combine + output_combine = self.combine_node.forward( + hidden_states_out, + self.token_dispatcher._comm_manager.group, + self.token_dispatcher._comm_manager.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + output_combine.stop_gradient = False + self.token_dispatcher._comm_manager.handle = None + return output_combine + + @paddle.no_grad() + def backward(self, output_combine_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + # combine grad -> fp8 + hidden_states_out_grad = self.combine_node.backward( + output_combine_grad, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return hidden_states_out_grad + + +class Fp8CombineQuantNode: + def __init__(self, token_dispatcher, moe_group=None, name="fp8_combine_quant_node"): + self.token_dispatcher = token_dispatcher + self.name = name + self.moe_group = moe_group + + @paddle.no_grad() + def forward(self, output_combine): + # post combine + output = output_combine.reshape(self.token_dispatcher.hidden_shape) + output_combine._record_stream() + self.output_combine_shape = output_combine.shape + output.stop_gradient = False + return output + + @paddle.no_grad() + def backward(self, output_grad, event_to_wait=None): + # post combine grad + if DSV3_USE_FP8_DISPATCH: + if event_to_wait is not None: + assert self.moe_group is not None + event_to_wait.comm_stream_wait(self.moe_group.id) + buffer = get_buffer(self.token_dispatcher._comm_manager.group, get_hidden_bytes(output_grad)) + custom_stream = paddle.device.Stream(stream_base=buffer.runtime.get_comm_stream()) + else: + custom_stream = paddle.device.current_stream() + with paddle.device.stream_guard(custom_stream): + 
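+ # Reshape the incoming gradient to 2-D and quantize it to FP8 with 1x128 blockwise scales
+ # on the selected stream: DeepEP's communication stream when an upstream event is supplied,
+ # otherwise the current compute stream. The event captured after this block lets later work
+ # on the comm stream wait for the quantized gradient.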
output_combine_grad = paddle.reshape(output_grad, [-1, output_grad.shape[-1]]) + # output_combine_grad quant to fp8 + output_combine_grad_fp8, output_combine_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + output_combine_grad, output_scale_transpose=False, quant_method="1x128", input_transpose=False + ) + output_grad._record_stream() + quant_event = None + if event_to_wait is not None: + quant_event = deep_ep.get_event_from_custom_stream(custom_stream.stream_base) + return (output_combine_grad_fp8, output_combine_grad_scale), quant_event + else: + output_combine_grad = paddle.reshape(output_grad, [-1, output_grad.shape[-1]]) + return output_combine_grad, None + + +class FusionMlpNode: + """ + The FusedMoeLayer class includes operations for unzipping, expert computation, and zipping. + """ + + def __init__( + self, + custom_map, + max_topk, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + mlp_fwd_subbatch_rows=0, + mlp_bwd_subbatch_rows=0, + output_subbatch_rows=0, + ): + self.token_dispatcher = custom_map.token_dispatcher + self.experts = custom_map.experts + self.unzip_node = UnZipNode() + self.zip_node = ZipNode() + self.experts_group_gemm_node = FP8GroupGemmMlpFunctionNode( + custom_map, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + ) + + self.seq_length = custom_map.config.seq_length + self.num_experts_per_tok = custom_map.config.num_experts_per_tok + self.adaptive_remained_O1_recompute_ratio = custom_map.config.adaptive_remained_O1_recompute_ratio + + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.dispatched_indices = None + self.dispatched_probs = None + self.tokens_per_expert = None + self.padding_token_per_experts = None + self.router_topk = max_topk + self.mlp_fwd_subbatch_rows = mlp_fwd_subbatch_rows + self.mlp_bwd_subbatch_rows = mlp_bwd_subbatch_rows + self.output_subbatch_rows = output_subbatch_rows + + def set_recompute_fwd_gate_up(self, recompute_fwd_gate_up): + self.experts_group_gemm_node.recompute_fwd_gate_up = recompute_fwd_gate_up + + def reset_statue(self): + """ + 重置所有状态变量。 + + Args: + 无。 + + Returns: + 无。 + + """ + self.dispatched_indices = None + self.dispatched_probs = None + self.tokens_per_expert = None + self.padding_token_per_experts = None + self.router_topk = None + + del self.unzip_node + del self.zip_node + self.unzip_node = None + self.zip_node = None + + self.experts_group_gemm_node.reset_statue() + self.experts_group_gemm_node = None + + def prepare_env_subbatch(self, unzipped_tokens=None, unzipped_tokens_scale=None, is_fwd=True): + if is_fwd: + assert unzipped_tokens is not None and unzipped_tokens_scale is not None + self.experts_group_gemm_node.input_fp8 = unzipped_tokens + self.experts_group_gemm_node.input_scale = unzipped_tokens_scale + self.m_indices = self.experts_group_gemm_node.gen_m_indices(self.padding_token_per_experts) + self.experts_group_gemm_node.fwd_subbatch = True + else: + self.m_indices = ( + self.experts_group_gemm_node.gen_m_indices(self.padding_token_per_experts) + if not hasattr(self, "m_indices") + else self.m_indices + ) + self.experts_group_gemm_node.bwd_subbatch = True + reload(self.experts_group_gemm_node.input_fp8) + reload(self.experts_group_gemm_node.input_scale) + + def gemm_forward_subbatch( + self, + unzipped_tokens, + unzipped_tokens_scale, + unzipped_probs, + map_unzipped_indices_to_zipped, + output, + total_zipped_tokens, + padding_token_per_experts, + start_idx=None, + end_idx=None, + output_subbatch_rows=None, + ): + if 
start_idx is None or end_idx is None: + start_idx = 0 + end_idx = unzipped_tokens.shape[0] + start_idx = max(0, start_idx) + end_idx = min(unzipped_tokens.shape[0], end_idx) + + expert_out = self.experts_group_gemm_node.forward( + (unzipped_tokens[start_idx:end_idx], unzipped_tokens_scale[start_idx:end_idx]), + unzipped_probs[start_idx:end_idx], + padding_token_per_experts, + m_indices=self.m_indices[start_idx:end_idx], + ) + + output = tokens_zip_unique_add_with_subbatch( + output, + expert_out, + map_unzipped_indices_to_zipped[start_idx:end_idx], + total_zipped_tokens, + subbatch_rows=output_subbatch_rows, + ) + return output + + def gemm_backward_subbatch( + self, + unzipped_grad, + map_unzipped_indices_to_zipped, + total_zipped_tokens, + output, + padding_token_per_experts, + start_idx=None, + end_idx=None, + output_subbatch_rows=None, + reset_status=False, + ): + def split_list_prefix(l, start, end): + prefix_sum = [0] * (len(l) + 1) + for i in range(len(l)): + prefix_sum[i + 1] = prefix_sum[i] + l[i] + + result = [] + for i in range(len(l)): + segment_start = prefix_sum[i] + segment_end = prefix_sum[i + 1] + overlap_start = max(start, segment_start) + overlap_end = min(end, segment_end) + selected = max(0, overlap_end - overlap_start) + result.append(selected) + return result + + if start_idx is None or end_idx is None: + start_idx = 0 + end_idx = extract_first_if_tuple(unzipped_grad).shape[0] + + start_idx = max(0, start_idx) + end_idx = min(extract_first_if_tuple(unzipped_grad).shape[0], end_idx) + + # m_indices = self.experts_group_gemm_node.gen_m_indices(self.tokens_per_expert) + unzipped_inp_grad = ( + (unzipped_grad[0][start_idx:end_idx].contiguous(), unzipped_grad[1][start_idx:end_idx].contiguous()) + if isinstance(unzipped_grad, tuple) + else unzipped_grad[start_idx:end_idx].contiguous() + ) + unzipped_grad, unzipped_probs_grad = self.experts_group_gemm_node.backward( + unzipped_inp_grad, + self.unzipped_probs[start_idx:end_idx].contiguous(), + input_fp8_slice=self.experts_group_gemm_node.input_fp8[start_idx:end_idx].contiguous(), + input_scale_slice=self.experts_group_gemm_node.input_scale[start_idx:end_idx].contiguous(), + tokens_per_expert=split_list_prefix(padding_token_per_experts, start_idx, end_idx), + m_indices=self.m_indices[start_idx:end_idx].contiguous(), + reset_status=reset_status, + ) + + output = tokens_zip_unique_add_with_subbatch( + output, + unzipped_grad, + map_unzipped_indices_to_zipped[start_idx:end_idx], + zipped_rows=total_zipped_tokens, + subbatch_rows=output_subbatch_rows, + ) + + return output, unzipped_probs_grad + + @paddle.no_grad() + def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs): + """ + 对输入数据进行前向传播计算。 + + Args: + hs_fp8_dispatched (Tensor): 表示被分派到各个专家的输入数据。 + dispatched_indices (Tensor):表示输入数据被分派到的专家索引。 + dispatched_probs (Tensor): 表示输入数据被分派到各个专家的概率。 + + Returns: + Tensor: 经过前向传播计算后的输出数据。 + + """ + self.tokens_per_expert = self.token_dispatcher._comm_manager.tokens_per_expert + self.dispatched_probs = dispatched_probs + num_experts = len(self.tokens_per_expert) + padding_token_per_experts = [(x + 127) // 128 * 128 for x in self.tokens_per_expert] + self.padding_token_per_experts = padding_token_per_experts + # 1 unzip + self.dispatched_indices = dispatched_indices.to(paddle.int32) + if DSV3_USE_FP8_DISPATCH: + ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + unzipped_tokens_scale, + ) = self.unzip_node.forward( + hs_2d_dispatched, + self.dispatched_indices, + dispatched_probs, + 
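+ # routing metadata captured at dispatch time; moe_permute pads each expert's token block
+ # to a multiple of 128 rows (see padding_token_per_experts above)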
topk=self.router_topk, + num_experts=num_experts, + tokens_per_expert=self.tokens_per_expert, + ) + record_stream_for_multi_input(hs_2d_dispatched) + dispatched_indices._record_stream() + dispatched_probs._record_stream() + + total_unzipped_tokens = extract_first_if_tuple(unzipped_tokens).shape[0] + total_zipped_tokens = extract_first_if_tuple(hs_2d_dispatched).shape[0] + + # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance + if self.recompute_fwd_gate_up == -1: + if ( + total_unzipped_tokens + > self.seq_length * self.num_experts_per_tok * self.adaptive_remained_O1_recompute_ratio + ): + # logger.debug(f"recompute_fwd_gate_up changed to True, Because the receives {unzipped_tokens.shape[0]} Tensors greater then {self.seq_length*self.num_experts_per_tok*self.adaptive_remained_O1_recompute_ratio}.") + self.set_recompute_fwd_gate_up(True) + else: + # logger.debug(f"recompute_fwd_gate_up changed to False, Because the receives {unzipped_tokens.shape[0]} Tensors less then {self.seq_length*self.num_experts_per_tok*self.adaptive_remained_O1_recompute_ratio}.") + self.set_recompute_fwd_gate_up(False) + + self.unzipped_probs = unzipped_probs.unsqueeze(-1) + + # if use_mlp_subbatch is enabled, then split the unzipped_tokens into subbatches + if self.mlp_fwd_subbatch_rows != 0 and total_unzipped_tokens > self.mlp_fwd_subbatch_rows * 2: + assert ( + self.experts_group_gemm_node.recompute_fwd_gate_up + ), "recompute_fwd_gate_up must be true when use_mlp_subbatch = True" + map_unzipped_indices_to_zipped = TDU.tokens_unzip_slice( + extract_first_if_tuple(hs_2d_dispatched), + zipped_expertwise_rowmap, + num_experts, + total_unzipped_tokens, + 0, + total_unzipped_tokens + 1, + ) + if isinstance(hs_2d_dispatched, tuple): + hs_2d_dispatched[0]._clear_to_zero_allocation() + hs_2d_dispatched[1]._clear_to_zero_allocation() + else: + hs_2d_dispatched._clear_to_zero_allocation() + + subbatch_rows = min((total_unzipped_tokens // num_experts) // 128 * 128, self.mlp_fwd_subbatch_rows) + nparts = (total_unzipped_tokens + subbatch_rows - 1) // subbatch_rows + output = paddle.empty([0, extract_first_if_tuple(hs_2d_dispatched).shape[-1]], dtype=paddle.float32) + self.prepare_env_subbatch(unzipped_tokens, unzipped_tokens_scale, True) + logger.info( + f"Enable subbatch_forward!! 
total_zipped_tokens:{total_zipped_tokens}, total_unzipped_tokens:{total_unzipped_tokens}, nparts:{nparts}, subbatch_rows:{subbatch_rows}, output_sub_rows:{self.output_subbatch_rows}" + ) + for i in range(nparts): + start_idx = i * subbatch_rows + end_idx = min(start_idx + subbatch_rows, total_unzipped_tokens) + output = self.gemm_forward_subbatch( + unzipped_tokens, + unzipped_tokens_scale, + unzipped_probs, + map_unzipped_indices_to_zipped, + output, + total_zipped_tokens, + padding_token_per_experts, + start_idx=start_idx, + end_idx=end_idx, + output_subbatch_rows=self.output_subbatch_rows, + ) + + output = merge_subbatch_cast(output, paddle.bfloat16) + output.stop_gradient = False + offload(self.experts_group_gemm_node.input_fp8) + offload(self.experts_group_gemm_node.input_scale) + return output + + # 2 experts + expert_out = self.experts_group_gemm_node.forward( + (unzipped_tokens, unzipped_tokens_scale), unzipped_probs, padding_token_per_experts + ) + else: + (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, _,) = self.unzip_node.forward( + hs_2d_dispatched, + self.dispatched_indices, + dispatched_probs, + topk=self.router_topk, + num_experts=num_experts, + tokens_per_expert=self.tokens_per_expert, + ) + hs_2d_dispatched._record_stream() + dispatched_indices._record_stream() + dispatched_probs._record_stream() + + # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance + if self.recompute_fwd_gate_up == -1: + if ( + unzipped_tokens.shape[0] + > self.seq_length * self.num_experts_per_tok * self.adaptive_remained_O1_recompute_ratio + ): + self.set_recompute_fwd_gate_up(True) + else: + self.set_recompute_fwd_gate_up(False) + + # 2 experts + expert_out = self.experts_group_gemm_node.forward( + unzipped_tokens, unzipped_probs, padding_token_per_experts + ) + + # 3 zip + if isinstance(hs_2d_dispatched, tuple): + hs_2d_dispatched[0]._clear_to_zero_allocation() + hs_2d_dispatched[1]._clear_to_zero_allocation() + else: + hs_2d_dispatched._clear_to_zero_allocation() + expert_out_tmp = expert_out.reshape([-1, expert_out.shape[-1]]) + + expert_out_zipped = self.zip_node.forward( + expert_out_tmp, + zipped_expertwise_rowmap, + self.dispatched_indices, + unzipped_probs, + total_zipped_tokens=total_zipped_tokens, + num_experts=num_experts, + ) + + expert_out_zipped.stop_gradient = False + return expert_out_zipped + + @paddle.no_grad() + def backward(self, hidden_states_out_grad): + """ + 反向传播函数。 + + Args: + hidden_states_out_grad_fp8 (Tensor): 隐藏状态梯度。 + + Returns: + Tuple[Tensor, Tensor]: 包含两个元素,分别为hs_fp8_dispatched_grad和dispatched_probs_grad。 + - hs_fp8_dispatched_grad (Tensor): 解压后的隐藏状态梯度。 + - dispatched_probs_grad (Tensor): 分发概率梯度。 + + """ + # zip_grad + unzipped_grad = self.zip_node.backward( + hidden_states_out_grad, + self.dispatched_indices, + self.dispatched_probs, + top_k=self.router_topk, + num_experts=len(self.tokens_per_expert), + tokens_per_expert=self.tokens_per_expert, + ) + record_stream_for_multi_input(hidden_states_out_grad) + + total_zipped_tokens = extract_first_if_tuple(hidden_states_out_grad).shape[0] + total_unzipped_tokens = extract_first_if_tuple(unzipped_grad).shape[0] + hidden_states_size = extract_first_if_tuple(hidden_states_out_grad).shape[-1] + num_experts = len(self.tokens_per_expert) + padding_token_per_experts = [(x + 127) // 128 * 128 for x in self.tokens_per_expert] + + if self.mlp_bwd_subbatch_rows != 0 and total_unzipped_tokens > self.mlp_bwd_subbatch_rows * 2: + map_unzipped_indices_to_zipped = 
TDU.tokens_unzip_slice( + extract_first_if_tuple(hidden_states_out_grad), + self.unzip_node.zipped_expertwise_rowmap, + num_experts, + total_unzipped_tokens, + 0, + total_unzipped_tokens + 1, + ) + if isinstance(hidden_states_out_grad, tuple): + hidden_states_out_grad[0]._clear_to_zero_allocation() + hidden_states_out_grad[1]._clear_to_zero_allocation() + else: + hidden_states_out_grad._clear_to_zero_allocation() + + subbatch_rows = min((total_unzipped_tokens // num_experts) // 128 * 128, self.mlp_bwd_subbatch_rows) + nparts = (total_unzipped_tokens + subbatch_rows - 1) // subbatch_rows + output = paddle.empty([0, hidden_states_size], dtype=paddle.float32) + probs_grad_list = [] + self.prepare_env_subbatch(is_fwd=False) + logger.info( + f"Enable subbatch_backward!! total_zipped_tokens:{total_zipped_tokens}, total_unzipped_tokens:{total_unzipped_tokens}, nparts:{nparts}, subbatch_rows:{subbatch_rows}, output_sub_rows:{self.output_subbatch_rows}" + ) + for i in range(nparts): + reset_status = True if i == nparts - 1 else False # release saved status in the last part. + start_idx = i * subbatch_rows + end_idx = min(start_idx + subbatch_rows, total_unzipped_tokens) + output, probs_grad = self.gemm_backward_subbatch( + unzipped_grad, + map_unzipped_indices_to_zipped, + total_zipped_tokens, + output, + padding_token_per_experts, + start_idx=start_idx, + end_idx=end_idx, + output_subbatch_rows=self.output_subbatch_rows, + reset_status=reset_status, + ) + probs_grad_list.append(probs_grad) + if isinstance(unzipped_grad, tuple): + unzipped_grad[0]._clear_to_zero_allocation() + unzipped_grad[1]._clear_to_zero_allocation() + else: + unzipped_grad._clear_to_zero_allocation() + hs_dispatched_grad = merge_subbatch_cast(output, paddle.bfloat16) + dispatched_probs_grad = TDU.tokens_zip_prob_seq_subbatch( + probs_grad_list, self.unzip_node.zipped_expertwise_rowmap, self.dispatched_indices, subbatch_rows + ) + self.reset_statue() + return hs_dispatched_grad, dispatched_probs_grad + + if isinstance(hidden_states_out_grad, tuple): + hidden_states_out_grad[0]._clear_to_zero_allocation() + hidden_states_out_grad[1]._clear_to_zero_allocation() + else: + hidden_states_out_grad._clear_to_zero_allocation() + + # expert_grad + expert_out, probs_grad = self.experts_group_gemm_node.backward( + unzipped_grad, self.unzipped_probs, padding_token_per_experts + ) + + hs_dispatched_grad, dispatched_probs_grad = self.unzip_node.backward( + expert_out, + total_zipped_tokens, + probs_grad, + self.dispatched_indices, + num_experts=num_experts, + ) + + self.reset_statue() + return hs_dispatched_grad, dispatched_probs_grad + + +class FusionMoeNode: + def __init__( + self, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + mlp_fwd_subbatch_rows=0, + mlp_bwd_subbatch_rows=0, + output_subbatch_rows=0, + name="fusion_moe_node", + ): + self.token_dispatcher = custom_map.token_dispatcher + self.moe_router_topk = custom_map.moe_router_topk + self.dispatch_quant_node = Fp8DispatchQuantNode(self.token_dispatcher) + self.dispatch_node = Fp8DispatchNode(self.token_dispatcher) + self.mlp_node = FusionMlpNode( + custom_map, + self.moe_router_topk, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + mlp_fwd_subbatch_rows=mlp_fwd_subbatch_rows, + mlp_bwd_subbatch_rows=mlp_bwd_subbatch_rows, + output_subbatch_rows=output_subbatch_rows, + ) + self.combine_node = Fp8CombineNode(self.token_dispatcher) + self.combine_quant_node = Fp8CombineQuantNode(self.token_dispatcher, 
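+ # the moe_group is passed so that backward() can wait on DeepEP's comm stream
+ # before quantizing the combine gradient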
custom_map.moe_group) + self.name = name + + @paddle.no_grad() + def forward(self, hidden_states, probs, routing_map): + if DSV3_USE_FP8_DISPATCH: + (hs_fp8, hs_scale), token_indices, token_probs = self.dispatch_quant_node.forward( + hidden_states, probs, routing_map + ) + ( + (hs_fp8_dispatched, hs_scale_dispatched), + dispatched_indices, + dispatched_probs, + ) = self.dispatch_node.forward((hs_fp8, hs_scale), token_indices, token_probs) + hidden_states_out = self.mlp_node.forward( + (hs_fp8_dispatched, hs_scale_dispatched), dispatched_indices, dispatched_probs + ) + output_combine = self.combine_node.forward(hidden_states_out) + output = self.combine_quant_node.forward(output_combine) + output.stop_gradient = False + return output + else: + hs_2d_dispatched, dispatched_indices, dispatched_probs = self.dispatch_node.forward( + hidden_states, probs, routing_map + ) + hidden_states_out = self.mlp_node.forward(hs_2d_dispatched, dispatched_indices, dispatched_probs) + output_combine = self.combine_node.forward(hidden_states_out) + output = self.combine_quant_node.forward(output_combine) + output.stop_gradient = False + return output + + @paddle.no_grad() + def backward(self, output_grad): + output_combine_grad, _ = self.combine_quant_node.backward(output_grad) + hidden_states_out_grad = self.combine_node.backward(output_combine_grad) + + hs_dispatched_grad, dispatched_probs_grad = self.mlp_node.backward(hidden_states_out_grad) + + if DSV3_USE_FP8_DISPATCH: + hs_fp8_grad, token_probs_grad = self.dispatch_node.backward(hs_dispatched_grad, dispatched_probs_grad) + hs_grad, probs_grad, routing_map_grad = self.dispatch_quant_node.backward(hs_fp8_grad, token_probs_grad) + return hs_grad, probs_grad, routing_map_grad + else: + hs_bf16_grad, token_probs_grad = self.dispatch_node.backward(hs_dispatched_grad, dispatched_probs_grad) + return hs_bf16_grad, None, token_probs_grad + + +class FusionMoe(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + hidden_states, + probs, + routing_map, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + ): + ctx.node = FusionMoeNode( + custom_map, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + ) + return ctx.node.forward(hidden_states, probs, routing_map) + + @staticmethod + def backward(ctx, output_grad): + return ctx.node.backward(output_grad) class MoEFlexTokenLayer(nn.Layer): def __init__(self, config, moe_num_experts, expert_class, expert_kwargs, gate, moe_group): diff --git a/paddleformers/transformers/moe_utils.py b/paddleformers/transformers/moe_utils.py index 466591b0638..4c5f3390e86 100644 --- a/paddleformers/transformers/moe_utils.py +++ b/paddleformers/transformers/moe_utils.py @@ -17,6 +17,51 @@ from typing import Optional import paddle +import numpy as np +import TokenDispatcherUtils as TDU + +from .fp8_utils import FP8LinearFunctionBase + +if not hasattr(paddle.Tensor, "_clear_to_zero_allocation"): + + def _clear_to_zero_allocation(self): + """ + _clear_to_zero_allocation + """ + old_shape = self.shape + dst = paddle.empty([0], dtype=self.dtype) + dst_t = dst.value().get_tensor() + src_t = self.value().get_tensor() + src_t._share_data_with(dst_t) + src_t._set_dims(old_shape) + + setattr(paddle.Tensor, "_clear_to_zero_allocation", _clear_to_zero_allocation) + + +if not hasattr(paddle.Tensor, "_holder_size"): + + def _holder_size(self): + """ + _holder_size + """ + if self._is_initialized(): + return int(np.prod(self.shape)) * paddle.core.size_of_dtype(self.dtype) + else: + 
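+ # no underlying holder (e.g. cleared via _clear_to_zero_allocation or offloaded): zero bytes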
return 0 + + setattr(paddle.Tensor, "_holder_size", _holder_size) + + +def topk_to_permuted_indices(x, num_tokens_per_expert_list, topk): + x = paddle.flatten(x) + prob_permuted_indices = paddle.concat( + [ + paddle.tensor.search._restrict_nonzero(x == i, total_true_num) + for i, total_true_num in enumerate(num_tokens_per_expert_list) + ] + ).flatten() + token_permuted_indices = prob_permuted_indices // topk + return token_permuted_indices, prob_permuted_indices def permute( @@ -99,3 +144,340 @@ def unpermute( include_self=True, ) return output_tokens + +class UnZipNode: + def __init__(self, name="unzip"): + self.name = name + self.unzipped_probs = None + self.zipped_expertwise_rowmap = None + + def reset_statue(self): + self.unzipped_probs = None + self.zipped_expertwise_rowmap = None + + @paddle.no_grad() + def forward( + self, + hs_2d_dispatched, + dispatched_indices, + dispatched_probs, + topk, + num_experts, + tokens_per_expert, + ): + if isinstance(hs_2d_dispatched, tuple): + with paddle.amp.auto_cast(False): + ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + unzipped_scale, + ) = paddle.nn.functional.moe_permute( + hs_2d_dispatched[0], + hs_2d_dispatched[1], + dispatched_indices, + dispatched_probs, + num_experts=num_experts, + tokens_per_expert=tokens_per_expert, + padding_alignment=128, + ) + else: + with paddle.amp.auto_cast(False): + ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + unzipped_scale, + ) = paddle.nn.functional.moe_permute( + hs_2d_dispatched, + None, + dispatched_indices, + dispatched_probs, + num_experts=num_experts, + tokens_per_expert=tokens_per_expert, + padding_alignment=128, + ) + self.unzipped_probs = unzipped_probs + self.zipped_expertwise_rowmap = zipped_expertwise_rowmap + return (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, unzipped_scale) + + @paddle.no_grad() + def backward(self, dx, total_zipped_tokens, probs_grad, dispatched_indices, num_experts): + with paddle.amp.auto_cast(False): + weighted_zipped_tokens, probs_grad_zipped = paddle.nn.functional.moe_unpermute( + dx, + self.zipped_expertwise_rowmap, + dispatched_indices, + probs_grad, + total_zipped_tokens=total_zipped_tokens, + num_experts=num_experts, + ) + self.reset_statue() + return weighted_zipped_tokens, probs_grad_zipped + + +class ZipNode: + def __init__(self, name="zip"): + self.name = name + + @paddle.no_grad() + def forward( + self, expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts + ): + with paddle.amp.auto_cast(False): + expert_out_zipped, zipped_probs_topk = paddle.nn.functional.moe_unpermute( + expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts + ) + return expert_out_zipped + + @paddle.no_grad() + def backward( + self, + grad_output, + dispatched_indices, + dispatched_probs, + top_k, + num_experts, + tokens_per_expert, + ): + if isinstance(grad_output, tuple): + with paddle.amp.auto_cast(False): + ( + unzipped_grad, + zipped_expertwise_rowmap_grad, + unzipped_probs_grad, + unzipped_scale_grad, + ) = paddle.nn.functional.moe_permute( + grad_output[0], + grad_output[1], + dispatched_indices, + dispatched_probs, + num_experts, + tokens_per_expert, + padding_alignment=128, + ) + return (unzipped_grad, unzipped_scale_grad) + else: + with paddle.amp.auto_cast(False): + ( + unzipped_grad, + zipped_expertwise_rowmap_grad, + unzipped_probs_grad, + unzipped_scale_grad, + ) = paddle.nn.functional.moe_permute( + grad_output, + None, + 
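+ # None: the bf16 gradient path carries no per-block FP8 activation scale to permute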
dispatched_indices, + dispatched_probs, + num_experts, + tokens_per_expert, + padding_alignment=128, + ) + + return unzipped_grad + + +class PermuteNode: + def __init__(self, token_dispatcher, name="permute"): + self.token_dispatcher = token_dispatcher + self.name = name + + def reset_status(self): + self.token_permuted_indices = None + self.prob_permuted_indices = None + + def forward(self, hidden_states, hidden_states_scale, dispatched_indices): + self.token_dispatcher._comm_manager.hidden_shape_before_permute = hidden_states.shape + self.hidden_shape_before_permute = hidden_states.shape + self.token_permuted_indices, self.prob_permuted_indices = topk_to_permuted_indices( + dispatched_indices, + self.token_dispatcher._comm_manager.tokens_per_expert, + self.token_dispatcher._comm_manager.router_topk, + ) + hidden_states = permute(hidden_states, self.token_permuted_indices) + # permute scale + hidden_states_scale = permute(hidden_states_scale, self.token_permuted_indices) + + return hidden_states, hidden_states_scale, self.token_permuted_indices, self.prob_permuted_indices + + def backward(self, out_grad, dispatched_probs): + input_dtype = out_grad.dtype + hidden_states_grad = unpermute( + permuted_tokens=out_grad, + token_permuted_indices=self.token_permuted_indices, + prob_permuted_indices=self.prob_permuted_indices, + restore_shape=self.hidden_shape_before_permute, + probs=dispatched_probs, + ) + self.reset_status() + return hidden_states_grad.to(input_dtype) + + +class UnPermuteNode: + def __init__(self, token_dispatcher, name="unpermute"): + self.token_dispatcher = token_dispatcher + self.name = name + + def reset_status(self): + self.token_permuted_indices = None + self.hidden_states = None + self.prob_permuted_indices = None + self.faltten_dispatched_probs = None + self.hidden = None + self.permuted_tokens = None + self.output_tokens = None + + def forward( + self, + hidden_states, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ): + self.token_permuted_indices = token_permuted_indices + self.input_dtype = hidden_states.dtype + self.hidden_states = hidden_states + self.prob_permuted_indices = prob_permuted_indices + self.dispatched_probs_shape = dispatched_probs.shape + # permute + _, self.hidden = self.token_dispatcher._comm_manager.hidden_shape_before_permute + + self.faltten_dispatched_probs = dispatched_probs.flatten() + + self.permuted_probs = paddle.gather(self.faltten_dispatched_probs, self.prob_permuted_indices) + permuted_tokens = self.hidden_states * self.permuted_probs.unsqueeze(-1) + permuted_tokens = permuted_tokens.cast(self.hidden_states.dtype) + + # Create an output tensor filled with zeros + output_tokens = paddle.zeros( + self.token_dispatcher._comm_manager.hidden_shape_before_permute, dtype=self.hidden_states.dtype + ) + # Scatter add the permuted_input back to the original positions + output_tokens.put_along_axis_( + axis=0, + indices=self.token_permuted_indices.cast("int32").unsqueeze(1).expand([-1, self.hidden]), + values=permuted_tokens, + reduce="add", + include_self=True, + ) + with paddle.base.device_guard("cpu"): + self.output_tokens = paddle.empty(shape=output_tokens.shape, dtype=output_tokens.dtype) + + return output_tokens.to(self.input_dtype) + + def backward(self, out_grad, out_grad_scale): + hidden_states_grad = paddle.gather(out_grad, self.token_permuted_indices) + + output_tokens_grad = FP8LinearFunctionBase.dequantize_fp8_to_fp32(out_grad, out_grad_scale) + permuted_tokens = self.hidden_states * 
self.permuted_probs.unsqueeze(-1) + permuted_tokens = permuted_tokens.cast(self.hidden_states.dtype) + + _, permuted_tokens_grad = paddle._C_ops.put_along_axis_grad( + self.output_tokens, + self.token_permuted_indices.cast("int32").unsqueeze(1).expand([-1, self.hidden]), + permuted_tokens, + self.output_tokens, + output_tokens_grad, + 0, + "add", + True, + ) + + permuted_probs_grad = (permuted_tokens_grad * self.hidden_states).sum(axis=-1) + + faltten_dispatched_probs_grad = paddle._C_ops.gather_grad( + self.faltten_dispatched_probs, self.prob_permuted_indices, permuted_probs_grad, 0 + ) + + # dispatched_probs_grad = paddle._C_ops.flatten_grad(self.dispatched_probs, faltten_dispatched_probs_grad) + dispatched_probs_grad = faltten_dispatched_probs_grad.reshape(self.dispatched_probs_shape) + + self.reset_status() + return hidden_states_grad, dispatched_probs_grad + + +def tokens_zip_unique_add_with_subbatch(zipped, unzipped, index_unzipped, zipped_rows, subbatch_rows=None): + """ + tokens_zip_unique_add_with_subbatch + """ + if subbatch_rows is None or subbatch_rows <= 0 or zipped_rows <= 0: + return TDU.tokens_zip_unique_add(zipped, unzipped, index_unzipped, zipped_rows) + else: + if isinstance(zipped, paddle.Tensor): + num_split = (zipped_rows + subbatch_rows - 1) // subbatch_rows + remainder = zipped_rows % subbatch_rows + if remainder == 0: + rows = [subbatch_rows] * num_split + else: + rows = [subbatch_rows] * (num_split - 1) + [remainder] + + if zipped.shape[0] == 0: + dtype = zipped.dtype + hidden_size = zipped.shape[1] + zipped = [paddle.zeros([r, hidden_size], dtype=dtype) for r in rows] + else: + zipped = paddle.split(zipped, rows, axis=0) + return TDU.tokens_zip_unique_add_subbatch(zipped, unzipped, index_unzipped, zipped_rows, subbatch_rows) + + +def merge_subbatch_cast(x, dtype): + if isinstance(x, (list, tuple)): + if len(x) == 1: + x = x[0] + return x.cast(dtype) if x.dtype != dtype else x + else: + return TDU.merge_subbatch_cast(x, dtype) + else: + return x.cast(dtype) if x.dtype != dtype else x + + +def get_env_device(): + """ + Return the device name of running environment. 
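+ Checked in order: CUDA, then the custom device types (npu, mlu, gcu, intel_hpu), then ROCm, XPU, and finally a CPU fallback.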
+ """ + if paddle.is_compiled_with_cuda(): + return "gpu" + elif "npu" in paddle.device.get_all_custom_device_type(): + return "npu" + elif "mlu" in paddle.device.get_all_custom_device_type(): + return "mlu" + elif "gcu" in paddle.device.get_all_custom_device_type(): + return "gcu" + elif "intel_hpu" in paddle.device.get_all_custom_device_type(): + return "intel_hpu" + elif paddle.is_compiled_with_rocm(): + return "rocm" + elif paddle.is_compiled_with_xpu(): + return "xpu" + return "cpu" + + +def to_device(tensor, place=None): + if place is None: + place = get_env_device() + + if isinstance(place, str): + place = paddle.device._convert_to_place(place) + + if not tensor.place._equals(place): + new_t = tensor._copy_to(place, True) + dst_tensor = tensor.value().get_tensor() + src_tensor = new_t.value().get_tensor() + dst_tensor._share_data_with(src_tensor) + + return tensor + + +def offload(tensor): + if paddle.is_compiled_with_cuda(): + place = paddle.CUDAPinnedPlace() + else: + place = paddle.CPUPlace() + + new_tensor = to_device(tensor, place) + assert new_tensor is tensor, "to_device must be inplace operation" + + +def reload(tensor): + new_tensor = to_device(tensor) + assert new_tensor is tensor, "to_device must be inplace operation" \ No newline at end of file diff --git a/paddleformers/transformers/token_dispatcher.py b/paddleformers/transformers/token_dispatcher.py index 128f6e52f4d..30a93e7de53 100644 --- a/paddleformers/transformers/token_dispatcher.py +++ b/paddleformers/transformers/token_dispatcher.py @@ -21,7 +21,7 @@ from paddle.distributed.communication.group import Group from .fused_a2a import fused_combine, fused_dispatch -from .moe_utils import permute, unpermute +from .moe_utils import permute, topk_to_permuted_indices, unpermute class _DispatchManager(ABC): @@ -127,7 +127,7 @@ def dispatch(self, hidden_states: paddle.Tensor) -> paddle.Tensor: self.dispatched_indices = states["dispatched_indices"] self.dispatched_probs = dispatched_probs - return hidden_states + return hidden_states, dispatched_indices, dispatched_probs def _indices_to_multihot(self, indices, probs): """ @@ -193,6 +193,34 @@ def get_restored_hidden_states_by_experts(self, hidden_states: paddle.Tensor) -> ) return hidden_states.to(input_dtype) + def get_permuted_hidden_states_by_experts_fast( + self, hidden_states: paddle.Tensor, dispatched_indices: paddle.Tensor + ) -> paddle.Tensor: + self.hidden_shape_before_permute = hidden_states.shape + token_permuted_indices, prob_permuted_indices = topk_to_permuted_indices( + dispatched_indices, self.tokens_per_expert, self.router_topk + ) + hidden_states = permute(hidden_states, token_permuted_indices) + return hidden_states, token_permuted_indices, prob_permuted_indices + + def get_restored_hidden_states_by_experts_fast( + self, + hidden_states: paddle.Tensor, + token_permuted_indices: paddle.Tensor, + prob_permuted_indices: paddle.Tensor, + dispatched_probs: paddle.Tensor, + ) -> paddle.Tensor: + input_dtype = hidden_states.dtype + assert dispatched_probs.dtype == paddle.float32, "DeepEP only supports float32 probs" + hidden_states = unpermute( + permuted_tokens=hidden_states, + token_permuted_indices=token_permuted_indices, + prob_permuted_indices=prob_permuted_indices, + restore_shape=self.hidden_shape_before_permute, + probs=dispatched_probs, + ) + return hidden_states.to(input_dtype) + class MoETokenDispatcher: """ @@ -260,6 +288,34 @@ def __init__(self, num_local_experts: int, moe_router_topk: int, num_moe_experts num_local_experts=self.num_local_experts, ) 
+ def pre_dispatch(self, hidden_states, probs, routing_map): + self.hidden_shape = hidden_states.shape + hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) + num_tokens = routing_map.shape[0] + routing_map = routing_map.reshape([num_tokens, self._comm_manager.num_experts]) + probs = probs.reshape([num_tokens, self._comm_manager.num_experts]) + # Convert the format of routing map from multihot to indices. + token_probs, token_indices = paddle.topk(probs, self._comm_manager.router_topk, axis=-1) + return hidden_states, token_indices, token_probs + + def post_dispatch(self, hidden_states, dispatched_indices): + ( + global_input_tokens, + token_permuted_indices, + prob_permuted_indices, + ) = self._comm_manager.get_permuted_hidden_states_by_experts_fast(hidden_states, dispatched_indices) + return (global_input_tokens, token_permuted_indices, prob_permuted_indices) + + def pre_combine(self, hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs): + hidden_states = self._comm_manager.get_restored_hidden_states_by_experts_fast( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + return hidden_states + + def post_combine(self, hidden_states): + hidden_states = hidden_states.reshape(self.hidden_shape) + return hidden_states + def token_permutation( self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor ) -> Tuple[paddle.Tensor, paddle.Tensor]: @@ -267,7 +323,7 @@ def token_permutation( hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) self._comm_manager.setup_metadata(routing_map, probs) - hidden_states = self._comm_manager.dispatch(hidden_states) + hidden_states, _, _ = self._comm_manager.dispatch(hidden_states) global_input_tokens = self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states) tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert() @@ -282,3 +338,79 @@ def token_unpermutation( hidden_states = hidden_states.reshape(self.hidden_shape) return hidden_states, None + + def token_permutation_fast( + self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + hidden_states, token_indices, token_probs = self.pre_dispatch(hidden_states, probs, routing_map) + hidden_states, dispatched_indices, dispatched_probs = self._comm_manager.dispatch( + hidden_states, token_indices, token_probs + ) + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.post_dispatch( + hidden_states, dispatched_indices + ) + + return ( + global_input_tokens, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ) + + def token_unpermutation_fast( + self, + hidden_states: paddle.Tensor, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + bias: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]: + assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher" + hidden_states = self.pre_combine( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + hidden_states = self._comm_manager.combine(hidden_states) + + hidden_states = self.post_combine(hidden_states) + return hidden_states, None + + +class PreDispatchNode: + def __init__(self, token_dispatcher): + self.token_dispatcher = token_dispatcher + self.probs_origin_shape = None + + def reset_status(self): + self.probs = None + self.reshaped_probs = None + self.token_indices = None + + @paddle.no_grad() + def forward(self, 
routing_map, probs): + num_tokens = routing_map.shape[0] + self.probs_origin_shape = probs.shape + # routing_map = routing_map.reshape([num_tokens, token_dispatcher._comm_manager.num_experts]) + self.probs = probs + reshaped_probs = probs.reshape([num_tokens, self.token_dispatcher._comm_manager.num_experts]) + self.reshaped_probs = reshaped_probs + token_probs, token_indices = paddle.topk( + reshaped_probs, self.token_dispatcher._comm_manager.router_topk, axis=-1 + ) + self.token_indices = token_indices + token_probs.stop_gradient = False + return token_indices, token_probs + + @paddle.no_grad() + def backward(self, token_probs_g): + probs_grad = paddle._C_ops.topk_grad( + self.reshaped_probs, + self.token_indices, + token_probs_g, + self.token_dispatcher._comm_manager.router_topk, + -1, + True, + True, + ) + probs_reshape_g = paddle._C_ops.reshape_grad(self.probs, probs_grad) + self.reset_status() + return probs_reshape_g \ No newline at end of file diff --git a/paddleformers/transformers/utils.py b/paddleformers/transformers/utils.py index 83c85fc147f..219e4e1d8b6 100644 --- a/paddleformers/transformers/utils.py +++ b/paddleformers/transformers/utils.py @@ -1005,3 +1005,9 @@ def caculate_llm_per_token_flops( # 2 for mul + add in matmul # 1 for forward, 2 for backwards since we caluate gradients for input_x and input_y return 2 * (layer_num * (flops_per_transformer * 3 + flops_recompute_transformer) + 3 * flops_loggits) / seq_length + +def cast_if_needed(x, dtype): + """ + cast_if_needed + """ + return x.cast(dtype) if x.dtype != dtype else x diff --git a/paddleformers/utils/download/download.py b/paddleformers/utils/download/download.py index bcc2e5bde70..f36e40f4bff 100644 --- a/paddleformers/utils/download/download.py +++ b/paddleformers/utils/download/download.py @@ -44,6 +44,7 @@ class DownloadSource(str, Enum): HUGGINGFACE = "huggingface" AISTUDIO = "aistudio" MODELSCOPE = "modelscope" + BOS = "bos" MODEL_MAPPINGS = {} @@ -64,6 +65,7 @@ def check_repo(model_name_or_path, download_hub): DownloadSource.HUGGINGFACE, DownloadSource.AISTUDIO, DownloadSource.MODELSCOPE, + DownloadSource.BOS, ], f"download_hub must be one of {DownloadSource.HUGGINGFACE}, {DownloadSource.AISTUDIO}, {DownloadSource.MODELSCOPE}" if model_name_or_path not in HF_MODEL_MAPPINGS.keys(): # repo id set by user @@ -87,6 +89,88 @@ def strtobool(v): f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)." 
) +from .aistudio_hub_download import ( + aistudio_hub_download, + aistudio_hub_file_exists, + aistudio_hub_try_to_load_from_cache, +) +from .bos_download import bos_download, bos_file_exists, bos_try_to_load_from_cache + + +def bos_aistudio_hf_file_exist( + repo_id: str, + filename: str, + *, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Optional[str] = None, + endpoint: Optional[str] = None, + from_bos: bool = True, + from_aistudio: bool = False, + from_hf_hub: bool = False, +): + assert repo_id is not None, "repo_id cannot be None" + assert filename is not None, "filename cannot be None" + + if subfolder is None: + subfolder = "" + filename = os.path.join(subfolder, filename) + if from_aistudio: + out = aistudio_hub_file_exists( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + token=token, + endpoint=endpoint, + ) + elif from_hf_hub: + out = hf_hub_file_exists( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + token=token, + ) + else: + out = bos_file_exists( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + token=token, # donot need token + endpoint=endpoint, + ) + return out + +def bos_aistudio_hf_try_to_load_from_cache( + repo_id: str, + filename: str, + cache_dir: Union[str, Path, None] = None, + subfolder: str = None, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + from_bos: bool = True, + from_aistudio: bool = False, + from_hf_hub: bool = False, +): + if subfolder is None: + subfolder = "" + load_kwargs = dict( + repo_id=repo_id, + filename=os.path.join(subfolder, filename), + cache_dir=cache_dir, + revision=revision, + repo_type=repo_type, + ) + if from_aistudio: + return aistudio_hub_try_to_load_from_cache(**load_kwargs) + elif from_hf_hub: + return hf_hub_try_to_load_from_cache(**load_kwargs) + else: + return bos_try_to_load_from_cache(**load_kwargs) + def resolve_file_path( repo_id: str = None, @@ -132,7 +216,6 @@ def resolve_file_path( if isinstance(filenames, str): filenames = [filenames] - # check repo id if download_hub is None: download_hub = os.environ.get("DOWNLOAD_SOURCE", "huggingface") @@ -238,6 +321,28 @@ def resolve_file_path( ) if cached_file is not None: return cached_file + else: + log_endpoint = "BOS" + for filename in filenames: + download_kwargs["filename"] = filename + is_available = bos_aistudio_hf_file_exist( + repo_id, + filename, + subfolder=subfolder, + repo_type=repo_type, + revision=revision, + token=token, + endpoint=endpoint, + from_bos=True, + from_aistudio=False, + from_hf_hub=False, + ) + if is_available: + cached_file = bos_download( + **download_kwargs, + ) + if cached_file is not None: + return cached_file except LocalEntryNotFoundError: raise EnvironmentError( "Cannot find the requested files in the cached path and"