Commit 1b69f8f

[XPU] support fused_rms_norm_ext (#74286)
* [XPU] support fused_rms_norm_ext
* fix
1 parent 8e85d72 commit 1b69f8f

File tree

7 files changed: +437 −16 lines changed
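For orientation, this is the computation that fused_rms_norm_ext fuses: per row, RMSNorm scales the input by the inverse root-mean-square of the row and by a learned per-column scale, and also returns that inverse RMS (invvar) for the backward pass. A minimal fp32 reference sketch follows (plain C++ for illustration only, not the Paddle kernel; the helper name rms_norm_row is hypothetical):

#include <cmath>
#include <cstddef>
#include <vector>

// Reference RMSNorm for one row (all fp32):
//   invvar = 1 / sqrt(mean(x^2) + epsilon)
//   y[i]   = x[i] * invvar * scale[i]
void rms_norm_row(const std::vector<float> &x,
                  const std::vector<float> &scale,
                  float epsilon,
                  std::vector<float> *y,
                  float *invvar) {
  double mean_sq = 0.0;
  for (float v : x) mean_sq += static_cast<double>(v) * v;
  mean_sq /= static_cast<double>(x.size());
  *invvar = 1.0f / std::sqrt(static_cast<float>(mean_sq) + epsilon);
  y->resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    (*y)[i] = x[i] * (*invvar) * scale[i];
  }
}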

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 10 additions & 1 deletion
@@ -1433,7 +1433,8 @@ XPUOpMap& get_kl3_ops() {
                       phi::DataType::FLOAT16,
                       phi::DataType::BFLOAT16,
                   })},
-      {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"sqrt_grad",
+       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"square_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"square",
@@ -1902,6 +1903,14 @@ XPUOpMap& get_kl3_ops() {
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
                      phi::DataType::BFLOAT16})},
+      {"fused_rms_norm_ext",
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16})},
+      {"fused_rms_norm_ext_grad",
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16})},
 #ifdef PADDLE_WITH_XPU_FFT
       {"conj",
        XPUKernelSet({phi::DataType::FLOAT32,
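These entries extend the KL3 op list so the XPU dispatcher will accept the listed dtypes for sqrt_grad, fused_rms_norm_ext and fused_rms_norm_ext_grad. As a simplified illustration of what such a table does (the real XPUOpMap/XPUKernelSet types live in Paddle's XPU backend; this sketch only mirrors the idea):

#include <map>
#include <set>
#include <string>

// Simplified stand-in for the op list: op name -> set of supported dtypes.
enum class DataType { FLOAT32, FLOAT16, BFLOAT16 };
using KernelSet = std::set<DataType>;
using OpMap = std::map<std::string, KernelSet>;

// An op/dtype pair is only dispatchable on the device if it appears here.
bool IsSupported(const OpMap &ops, const std::string &name, DataType dtype) {
  auto it = ops.find(name);
  return it != ops.end() && it->second.count(dtype) > 0;
}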

paddle/phi/infermeta/backward.cc

Lines changed: 9 additions & 6 deletions
@@ -2119,17 +2119,20 @@ void FusedRMSNormGradInferMeta(const MetaTensor& x,
                                MetaTensor* x_grad,
                                MetaTensor* scale_grad) {
   PADDLE_ENFORCE_EQ(
-      x.dtype() == DataType::FLOAT32 || x.dtype() == DataType::BFLOAT16,
+      x.dtype() == DataType::FLOAT32 || x.dtype() == DataType::FLOAT16 ||
+          x.dtype() == DataType::BFLOAT16,
       true,
       common::errors::InvalidArgument(
-          "The dtype of x must be FLOAT32 or BFLOAT16, but got [%s]",
+          "The dtype of x must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]",
           x.dtype()));
   PADDLE_ENFORCE_EQ(
-      scale.dtype() == DataType::FLOAT32 || scale.dtype() == DataType::BFLOAT16,
+      scale.dtype() == DataType::FLOAT32 ||
+          scale.dtype() == DataType::FLOAT16 ||
+          scale.dtype() == DataType::BFLOAT16,
       true,
-      common::errors::InvalidArgument(
-          "The dtype of scale must be FLOAT32 or BFLOAT16, but got [%s]",
-          scale.dtype()));
+      common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
+                                      "FLOAT16 or BFLOAT16, but got [%s]",
+                                      scale.dtype()));
   if (x_grad && x) {
     x_grad->share_meta(x);
   }

paddle/phi/infermeta/binary.cc

Lines changed: 9 additions & 6 deletions
@@ -4694,17 +4694,20 @@ void FusedRMSNormInferMeta(const MetaTensor& x,
                        scale_shape[0],
                        x_shape[x_shape.size() - 1]));
   PADDLE_ENFORCE_EQ(
-      x.dtype() == DataType::FLOAT32 || x.dtype() == DataType::BFLOAT16,
+      x.dtype() == DataType::FLOAT32 || x.dtype() == DataType::FLOAT16 ||
+          x.dtype() == DataType::BFLOAT16,
       true,
       common::errors::InvalidArgument(
-          "The dtype of x must be FLOAT32 or BFLOAT16, but got [%s]",
+          "The dtype of x must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]",
           x.dtype()));
   PADDLE_ENFORCE_EQ(
-      scale.dtype() == DataType::FLOAT32 || scale.dtype() == DataType::BFLOAT16,
+      scale.dtype() == DataType::FLOAT32 ||
+          scale.dtype() == DataType::FLOAT16 ||
+          scale.dtype() == DataType::BFLOAT16,
       true,
-      common::errors::InvalidArgument(
-          "The dtype of scale must be FLOAT32 or BFLOAT16, but got [%s]",
-          scale.dtype()));
+      common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
+                                      "FLOAT16 or BFLOAT16, but got [%s]",
+                                      scale.dtype()));
 
   y->set_dims(x.dims());
   y->set_dtype(scale.dtype());
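Both InferMeta functions now accept FLOAT16 alongside FLOAT32 and BFLOAT16, and the forward InferMeta keeps the existing rule that the output takes x's shape but scale's dtype. A small standalone sketch of that check-and-propagate pattern (hypothetical helper, not Paddle code):

#include <stdexcept>

// Hypothetical mirror of the checks above: x and scale must each be fp32,
// fp16 or bf16; the output then takes x's shape and scale's dtype.
enum class DType { FLOAT32, FLOAT16, BFLOAT16, OTHER };

DType InferOutputDtype(DType x_dtype, DType scale_dtype) {
  auto ok = [](DType d) {
    return d == DType::FLOAT32 || d == DType::FLOAT16 || d == DType::BFLOAT16;
  };
  if (!ok(x_dtype)) {
    throw std::invalid_argument("x must be FLOAT32, FLOAT16 or BFLOAT16");
  }
  if (!ok(scale_dtype)) {
    throw std::invalid_argument("scale must be FLOAT32, FLOAT16 or BFLOAT16");
  }
  return scale_dtype;  // output dtype follows the scale dtype
}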

paddle/phi/kernels/legacy/gpu/layer_norm_cuda_kernel.cu

Lines changed: 15 additions & 2 deletions
@@ -45,10 +45,15 @@ void RMSLnFwd(const Context &dev_ctx,
   const auto &scale_shape = scale.dims();
   int rows, cols;
   GetRowsCols(common::vectorize(x.dims()), &rows, &cols);
-  if (scale.dtype() == phi::DataType::BFLOAT16)
+  if (scale.dtype() == phi::DataType::BFLOAT16) {
     dev_ctx.template Alloc<phi::bfloat16>(y);
-  else if (scale.dtype() == phi::DataType::FLOAT32)
+  } else if (scale.dtype() == phi::DataType::FLOAT32) {
     dev_ctx.template Alloc<float>(y);
+  } else {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]",
+        scale.dtype()));
+  }
   invvar->Resize({rows});
   dev_ctx.template Alloc<float>(invvar);
   cuda_rms_norm<T, Context>(dev_ctx, x, scale, rows, cols, epsilon, y, invvar);
@@ -71,6 +76,10 @@ void RMSLnBwd(const Context &dev_ctx,
     dev_ctx.template Alloc<phi::bfloat16>(scale_grad);
   } else if (scale.dtype() == phi::DataType::FLOAT32) {
     dev_ctx.template Alloc<float>(scale_grad);
+  } else {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]",
+        scale.dtype()));
   }
   cuda_rms_norm_gradient<T, Context>(dev_ctx,
                                      x,
@@ -110,6 +119,10 @@ void RMSLnBwd(const Context &dev_ctx,
                                        epsilon,
                                        x_grad,
                                        &scale_grad_tmp);
+  } else {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "The dtype of scale must be FLOAT32, BFLOAT16, but got [%s]",
+        scale.dtype()));
   }
  }
 }

paddle/phi/kernels/xpu/activation_grad_kernel.cc

Lines changed: 7 additions & 1 deletion
@@ -802,12 +802,18 @@ PD_REGISTER_KERNEL(rsqrt_grad,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
 
+PD_REGISTER_KERNEL(sqrt_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::SqrtGradKernel,
+                   float,
+                   phi::dtype::float16) {}
+
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, HardSigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sin_grad, SinGradKernel)
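sqrt_grad on XPU is now registered via PD_REGISTER_KERNEL with an explicit float/float16 list rather than the fp32-only PD_REGISTER_ACTIVATION_GRAD_KERNEL macro, matching the op-list change above. For reference, the gradient it must produce follows directly from y = sqrt(x): dx = dy / (2 * y). A scalar sketch (illustration only, not the Paddle kernel):

#include <cassert>

// Scalar reference for sqrt_grad: with y = sqrt(x) and upstream grad dy,
// d(sqrt(x))/dx = 1 / (2 * sqrt(x)) = 1 / (2 * y), so dx = dy / (2 * y).
float sqrt_grad(float y, float dy) {
  assert(y > 0.0f);  // the gradient is only defined for positive sqrt outputs
  return dy / (2.0f * y);
}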
Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <cassert>
+
+#include "paddle/common/exception.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+static void GetRowsCols(const std::vector<int64_t> &shape,
+                        int64_t *p_rows,
+                        int64_t *p_cols) {
+  int64_t rows = 1;
+  for (size_t i = 0; i + 1 < shape.size(); ++i) {
+    rows *= shape[i];
+  }
+  int64_t cols = shape[shape.size() - 1];
+  *p_rows = rows;
+  *p_cols = cols;
+}
+
+template <typename T, typename Context>
+void RMSLnFwd(const Context &dev_ctx,
+              const DenseTensor &x,
+              const DenseTensor &scale,
+              float epsilon,
+              DenseTensor *y,
+              DenseTensor *invvar) {
+  int64_t rows, cols;
+  GetRowsCols(common::vectorize(x.dims()), &rows, &cols);
+
+  if (scale.dtype() == phi::DataType::BFLOAT16) {
+    dev_ctx.template Alloc<phi::bfloat16>(y);
+  } else if (scale.dtype() == phi::DataType::FLOAT16) {
+    dev_ctx.template Alloc<phi::float16>(y);
+  } else if (scale.dtype() == phi::DataType::FLOAT32) {
+    dev_ctx.template Alloc<float>(y);
+  } else {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "The dtype of scale must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]",
+        scale.dtype()));
+  }
+  invvar->Resize({rows});
+  dev_ctx.template Alloc<float>(invvar);
+
+  /*
+  refer to:
+  - https://github.com/NVIDIA/apex/blob/bfb500c8/csrc/layer_norm_cuda_kernel.cu#L1018
+  - https://github.com/PaddlePaddle/PaddleNLP/blob/5b9e0b33/ops/csrc/fused_ln/layer_norm_cuda.h#L1087
+
+  Supported Type combinations:
+
+  input  compute  scale  output
+  =======================================
+  fp32   fp32     fp32   fp32
+  fp16   fp32     fp16   fp16
+  bf16   fp32     bf16   bf16
+
+  Not supported yet:
+
+  input  compute  scale  output
+  =======================================
+  fp32   fp32     fp16   fp16
+  fp32   fp32     bf16   bf16
+
+  Remarks:
+  Output type = Scale type
+  Compute always in FP32
+  */
+
+#define DISPATCH_FWD_CASE(scalar_t_out)                              \
+  using XPUType = typename XPUTypeTrait<scalar_t_out>::Type;         \
+  auto ret = xpu::rms_layer_norm<XPUType, XPUType>(                  \
+      dev_ctx.x_context(),                                           \
+      reinterpret_cast<const XPUType *>(x.data<scalar_t_out>()),     \
+      reinterpret_cast<XPUType *>(y->data<scalar_t_out>()),          \
+      rows,                                                          \
+      cols,                                                          \
+      epsilon,                                                       \
+      reinterpret_cast<const XPUType *>(scale.data<scalar_t_out>()), \
+      /*bias=*/nullptr,                                              \
+      invvar->data<float>(),                                         \
+      /*is_rstd=*/true);                                             \
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "rms_layer_norm");
+  // scale.dtype() same as y->dtype()
+  if (x.dtype() == phi::DataType::FLOAT32 &&
+      scale.dtype() == phi::DataType::FLOAT32) {
+    DISPATCH_FWD_CASE(float);
+  } else if (x.dtype() == phi::DataType::FLOAT16 &&
+             scale.dtype() == phi::DataType::FLOAT16) {
+    DISPATCH_FWD_CASE(phi::float16);
+  } else if (x.dtype() == phi::DataType::BFLOAT16 &&
+             scale.dtype() == phi::DataType::BFLOAT16) {
+    DISPATCH_FWD_CASE(phi::bfloat16);
+  } else {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "Unsupported dtype combination: x [%s], scale [%s]. "
+        "Expected both to be float32, float16, or bfloat16.",
+        phi::DataTypeToString(x.dtype()),
+        phi::DataTypeToString(scale.dtype())));
+  }
+#undef DISPATCH_FWD_CASE
+}
+
+template <typename T, typename Context>
+void RMSLnBwd(const Context &dev_ctx,
+              const DenseTensor &x,
+              const DenseTensor &scale,
+              const DenseTensor &invvar,
+              const DenseTensor &y_grad,
+              float epsilon,
+              DenseTensor *x_grad,
+              DenseTensor *scale_grad) {
+  int64_t rows, cols;
+  GetRowsCols(common::vectorize(x.dims()), &rows, &cols);
+  dev_ctx.template Alloc<T>(x_grad);
+  DenseTensor actual_scale_grad;
+  if (scale_grad) {
+    if (scale.dtype() == phi::DataType::BFLOAT16) {
+      dev_ctx.template Alloc<phi::bfloat16>(scale_grad);
+    } else if (scale.dtype() == phi::DataType::FLOAT16) {
+      dev_ctx.template Alloc<phi::float16>(scale_grad);
+    } else if (scale.dtype() == phi::DataType::FLOAT32) {
+      dev_ctx.template Alloc<float>(scale_grad);
+    } else {
+      PADDLE_THROW(
+          common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
+                                          "FLOAT16 or BFLOAT16, but got [%s]",
+                                          scale.dtype()));
+    }
+    actual_scale_grad = *scale_grad;
+  } else {
+    // lora specific, scale_grad is nullptr
+    if (scale.dtype() == phi::DataType::BFLOAT16) {
+      actual_scale_grad =
+          phi::EmptyLike<phi::bfloat16, Context>(dev_ctx, scale);
+    } else if (scale.dtype() == phi::DataType::FLOAT16) {
+      actual_scale_grad = phi::EmptyLike<phi::float16, Context>(dev_ctx, scale);
+    } else if (scale.dtype() == phi::DataType::FLOAT32) {
+      actual_scale_grad = phi::EmptyLike<float, Context>(dev_ctx, scale);
+    } else {
+      PADDLE_THROW(
+          common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
+                                          "FLOAT16 or BFLOAT16, but got [%s]",
+                                          scale.dtype()));
+    }
+  }
+
+#define DISPATCH_BWD_CASE(scalar_t_out)                                    \
+  using XPUType = typename XPUTypeTrait<scalar_t_out>::Type;               \
+  auto ret = xpu::rms_layer_norm_grad<XPUType, XPUType>(                   \
+      dev_ctx.x_context(),                                                 \
+      reinterpret_cast<const XPUType *>(x.data<scalar_t_out>()),           \
+      reinterpret_cast<const XPUType *>(y_grad.data<scalar_t_out>()),      \
+      reinterpret_cast<XPUType *>(x_grad->data<scalar_t_out>()),           \
+      rows,                                                                \
+      cols,                                                                \
+      epsilon,                                                             \
+      reinterpret_cast<const XPUType *>(scale.data<scalar_t_out>()),       \
+      invvar.data<float>(),                                                \
+      reinterpret_cast<XPUType *>(actual_scale_grad.data<scalar_t_out>()), \
+      /*bias=*/nullptr,                                                    \
+      /*is_rstd=*/true);                                                   \
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "rms_layer_norm_grad");
+  // scale.dtype() same as y->dtype()
+  if (x.dtype() == phi::DataType::FLOAT32 &&
+      scale.dtype() == phi::DataType::FLOAT32) {
+    DISPATCH_BWD_CASE(float);
+  } else if (x.dtype() == phi::DataType::FLOAT16 &&
+             scale.dtype() == phi::DataType::FLOAT16) {
+    DISPATCH_BWD_CASE(phi::float16);
+  } else if (x.dtype() == phi::DataType::BFLOAT16 &&
+             scale.dtype() == phi::DataType::BFLOAT16) {
+    DISPATCH_BWD_CASE(phi::bfloat16);
+  } else {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "Unsupported dtype combination: x [%s], scale [%s]. "
+        "Expected both to be float32, float16, or bfloat16.",
+        phi::DataTypeToString(x.dtype()),
+        phi::DataTypeToString(scale.dtype())));
+  }
+#undef DISPATCH_BWD_CASE
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_rms_norm_ext,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::RMSLnFwd,
+                   float,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(fused_rms_norm_ext_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::RMSLnBwd,
+                   float,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
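Design notes, as stated in the kernel's own comment block: x and scale must share a dtype (fp32, fp16, or bf16), the output dtype follows the scale dtype, arithmetic is carried out in fp32 inside xpu::rms_layer_norm / xpu::rms_layer_norm_grad, and the saved invvar (per-row inverse RMS) is always fp32. For reference, a plain fp32 sketch of the per-row backward math one would expect from this op, derived here from the forward definition rather than taken from the XDNN implementation:

#include <cstddef>
#include <vector>

// Reference RMSNorm backward for one row (fp32), with xhat[i] = x[i] * invvar:
//   dscale[i] += dy[i] * xhat[i]                          (accumulated over rows)
//   dx[i]      = invvar * (dy[i] * scale[i]
//                          - xhat[i] * mean_j(dy[j] * scale[j] * xhat[j]))
// dscale_accum must already be sized to the row length and zero-initialized.
void rms_norm_grad_row(const std::vector<float> &x,
                       const std::vector<float> &scale,
                       const std::vector<float> &dy,
                       float invvar,
                       std::vector<float> *dx,
                       std::vector<float> *dscale_accum) {
  const std::size_t n = x.size();
  double c = 0.0;  // mean over the row of dy * scale * xhat
  for (std::size_t i = 0; i < n; ++i) {
    c += static_cast<double>(dy[i]) * scale[i] * (x[i] * invvar);
  }
  c /= static_cast<double>(n);
  dx->resize(n);
  for (std::size_t i = 0; i < n; ++i) {
    const float xhat = x[i] * invvar;
    (*dscale_accum)[i] += dy[i] * xhat;
    (*dx)[i] = invvar * (dy[i] * scale[i] - xhat * static_cast<float>(c));
  }
}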
