// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>

#include "paddle/common/exception.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

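// GetRowsCols flattens the input shape into a 2-D view: all leading
// dimensions are folded into `rows` and the last dimension becomes `cols`.
// For example, an input of shape [batch, seq_len, hidden] is treated as
// rows = batch * seq_len independent vectors of length cols = hidden, and
// the RMS statistics are computed per row over that last dimension.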
static void GetRowsCols(const std::vector<int64_t> &shape,
                        int64_t *p_rows,
                        int64_t *p_cols) {
  assert(!shape.empty());
  int64_t rows = 1;
  for (size_t i = 0; i + 1 < shape.size(); ++i) {
    rows *= shape[i];
  }
  int64_t cols = shape[shape.size() - 1];
  *p_rows = rows;
  *p_cols = cols;
}

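// RMS LayerNorm forward. Conceptually (standard RMSNorm, see the references
// inside the function), for each row:
//   invvar = 1 / sqrt(mean(x * x) + epsilon)
//   y      = x * invvar * scale
// `invvar` is saved as a float tensor of shape {rows} for the backward pass,
// and `y` takes the dtype of `scale`.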
template <typename T, typename Context>
void RMSLnFwd(const Context &dev_ctx,
              const DenseTensor &x,
              const DenseTensor &scale,
              float epsilon,
              DenseTensor *y,
              DenseTensor *invvar) {
  int64_t rows, cols;
  GetRowsCols(common::vectorize(x.dims()), &rows, &cols);

  if (scale.dtype() == phi::DataType::BFLOAT16) {
    dev_ctx.template Alloc<phi::bfloat16>(y);
  } else if (scale.dtype() == phi::DataType::FLOAT16) {
    dev_ctx.template Alloc<phi::float16>(y);
  } else if (scale.dtype() == phi::DataType::FLOAT32) {
    dev_ctx.template Alloc<float>(y);
  } else {
    PADDLE_THROW(common::errors::InvalidArgument(
        "The dtype of scale must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]",
        scale.dtype()));
  }
  invvar->Resize({rows});
  dev_ctx.template Alloc<float>(invvar);

  /*
  refer to:
  - https://github.com/NVIDIA/apex/blob/bfb500c8/csrc/layer_norm_cuda_kernel.cu#L1018
  - https://github.com/PaddlePaddle/PaddleNLP/blob/5b9e0b33/ops/csrc/fused_ln/layer_norm_cuda.h#L1087

  Supported type combinations:

  input   compute   scale   output
  =================================
  fp32    fp32      fp32    fp32
  fp16    fp32      fp16    fp16
  bf16    fp32      bf16    bf16

  Not supported yet:

  input   compute   scale   output
  =================================
  fp32    fp32      fp16    fp16
  fp32    fp32      bf16    bf16

  Remarks:
  Output type = Scale type
  Compute always in FP32
  */

#define DISPATCH_FWD_CASE(scalar_t_out)                                \
  using XPUType = typename XPUTypeTrait<scalar_t_out>::Type;           \
  auto ret = xpu::rms_layer_norm<XPUType, XPUType>(                    \
      dev_ctx.x_context(),                                             \
      reinterpret_cast<const XPUType *>(x.data<scalar_t_out>()),       \
      reinterpret_cast<XPUType *>(y->data<scalar_t_out>()),            \
      rows,                                                            \
      cols,                                                            \
      epsilon,                                                         \
      reinterpret_cast<const XPUType *>(scale.data<scalar_t_out>()),   \
      /*bias=*/nullptr,                                                \
      invvar->data<float>(),                                           \
      /*is_rstd=*/true);                                               \
  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "rms_layer_norm");
  // y was allocated with scale's dtype above, so x and scale (and therefore
  // y) must all share the same dtype for the supported cases.
  if (x.dtype() == phi::DataType::FLOAT32 &&
      scale.dtype() == phi::DataType::FLOAT32) {
    DISPATCH_FWD_CASE(float);
  } else if (x.dtype() == phi::DataType::FLOAT16 &&
             scale.dtype() == phi::DataType::FLOAT16) {
    DISPATCH_FWD_CASE(phi::float16);
  } else if (x.dtype() == phi::DataType::BFLOAT16 &&
             scale.dtype() == phi::DataType::BFLOAT16) {
    DISPATCH_FWD_CASE(phi::bfloat16);
  } else {
    PADDLE_THROW(common::errors::InvalidArgument(
        "Unsupported dtype combination: x [%s], scale [%s]. "
        "Expected both to be float32, float16, or bfloat16.",
        phi::DataTypeToString(x.dtype()),
        phi::DataTypeToString(scale.dtype())));
  }
#undef DISPATCH_FWD_CASE
}

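// RMS LayerNorm backward. Consumes the forward inputs together with the saved
// `invvar` and the upstream gradient `y_grad`, and produces `x_grad` (dtype T)
// and, when requested, `scale_grad` (dtype of `scale`).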
template <typename T, typename Context>
void RMSLnBwd(const Context &dev_ctx,
              const DenseTensor &x,
              const DenseTensor &scale,
              const DenseTensor &invvar,
              const DenseTensor &y_grad,
              float epsilon,
              DenseTensor *x_grad,
              DenseTensor *scale_grad) {
  int64_t rows, cols;
  GetRowsCols(common::vectorize(x.dims()), &rows, &cols);
  dev_ctx.template Alloc<T>(x_grad);
  DenseTensor actual_scale_grad;
  if (scale_grad) {
    if (scale.dtype() == phi::DataType::BFLOAT16) {
      dev_ctx.template Alloc<phi::bfloat16>(scale_grad);
    } else if (scale.dtype() == phi::DataType::FLOAT16) {
      dev_ctx.template Alloc<phi::float16>(scale_grad);
    } else if (scale.dtype() == phi::DataType::FLOAT32) {
      dev_ctx.template Alloc<float>(scale_grad);
    } else {
      PADDLE_THROW(
          common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
                                          "FLOAT16 or BFLOAT16, but got [%s]",
                                          scale.dtype()));
    }
    actual_scale_grad = *scale_grad;
  } else {
    // LoRA-specific path: scale is frozen, so scale_grad is nullptr. The xdnn
    // call below still needs a buffer to write the scale gradient into, so
    // allocate a temporary tensor and discard it afterwards.
    if (scale.dtype() == phi::DataType::BFLOAT16) {
      actual_scale_grad =
          phi::EmptyLike<phi::bfloat16, Context>(dev_ctx, scale);
    } else if (scale.dtype() == phi::DataType::FLOAT16) {
      actual_scale_grad = phi::EmptyLike<phi::float16, Context>(dev_ctx, scale);
    } else if (scale.dtype() == phi::DataType::FLOAT32) {
      actual_scale_grad = phi::EmptyLike<float, Context>(dev_ctx, scale);
    } else {
      PADDLE_THROW(
          common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
                                          "FLOAT16 or BFLOAT16, but got [%s]",
                                          scale.dtype()));
    }
  }

#define DISPATCH_BWD_CASE(scalar_t_out)                                    \
  using XPUType = typename XPUTypeTrait<scalar_t_out>::Type;               \
  auto ret = xpu::rms_layer_norm_grad<XPUType, XPUType>(                   \
      dev_ctx.x_context(),                                                 \
      reinterpret_cast<const XPUType *>(x.data<scalar_t_out>()),           \
      reinterpret_cast<const XPUType *>(y_grad.data<scalar_t_out>()),      \
      reinterpret_cast<XPUType *>(x_grad->data<scalar_t_out>()),           \
      rows,                                                                \
      cols,                                                                \
      epsilon,                                                             \
      reinterpret_cast<const XPUType *>(scale.data<scalar_t_out>()),       \
      invvar.data<float>(),                                                \
      reinterpret_cast<XPUType *>(actual_scale_grad.data<scalar_t_out>()), \
      /*bias=*/nullptr,                                                    \
      /*is_rstd=*/true);                                                   \
  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "rms_layer_norm_grad");
  // As in the forward pass, x and scale must share the same dtype.
  if (x.dtype() == phi::DataType::FLOAT32 &&
      scale.dtype() == phi::DataType::FLOAT32) {
    DISPATCH_BWD_CASE(float);
  } else if (x.dtype() == phi::DataType::FLOAT16 &&
             scale.dtype() == phi::DataType::FLOAT16) {
    DISPATCH_BWD_CASE(phi::float16);
  } else if (x.dtype() == phi::DataType::BFLOAT16 &&
             scale.dtype() == phi::DataType::BFLOAT16) {
    DISPATCH_BWD_CASE(phi::bfloat16);
  } else {
    PADDLE_THROW(common::errors::InvalidArgument(
        "Unsupported dtype combination: x [%s], scale [%s]. "
        "Expected both to be float32, float16, or bfloat16.",
        phi::DataTypeToString(x.dtype()),
        phi::DataTypeToString(scale.dtype())));
  }
#undef DISPATCH_BWD_CASE
}

}  // namespace phi

PD_REGISTER_KERNEL(fused_rms_norm_ext,
                   XPU,
                   ALL_LAYOUT,
                   phi::RMSLnFwd,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(fused_rms_norm_ext_grad,
                   XPU,
                   ALL_LAYOUT,
                   phi::RMSLnBwd,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
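
// Minimal call sketch (illustrative only; assumes an XPU dev_ctx and device
// tensors, and uses the phi-level entry points defined in this file rather
// than a public Python API):
//
//   phi::DenseTensor y, invvar;
//   phi::RMSLnFwd<phi::dtype::float16, phi::XPUContext>(
//       dev_ctx, x, scale, /*epsilon=*/1e-6f, &y, &invvar);
//   // y has scale's dtype; invvar is a float tensor of shape {rows}.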