Commit 20f736a

Adding fast_ln & fast_rms_norm and their backward op (PaddlePaddle#76360)
* finished main architecture, start optest and kernel refactor
* fix issues, add docs
* add optest
* fix optest
* Add optest & compile bypass
* recover diff
* test=document_fix
1 parent 9c62a8a commit 20f736a

26 files changed, +3469 -0 lines changed
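
For reference, the two new ops compute standard layer normalization and RMS normalization over the last axis of x (the columns of the flattened 2-D view used by the InferMeta functions below). The mean and invvar tensors saved for the backward pass follow the usual apex-style convention; the equations here are a summary for orientation, not part of the diff:

    LayerNorm:  mean = E[x],  var = E[(x - mean)^2]        (per row)
                invvar = 1 / sqrt(var + epsilon)
                y = scale * (x - mean) * invvar + bias

    RMSNorm:    invvar = 1 / sqrt(E[x^2] + epsilon)        (per row)
                y = scale * x * invvar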

paddle/phi/infermeta/backward.cc

Lines changed: 64 additions & 0 deletions
@@ -2273,6 +2273,70 @@ void FusedRMSNormGradInferMeta(const MetaTensor& x,
   }
 }
 
+PADDLE_API void FastLayerNormGradInfermeta(const MetaTensor& x,
+                                           const MetaTensor& scale,
+                                           const MetaTensor& mean,
+                                           const MetaTensor& invvar,
+                                           const MetaTensor& y_grad,
+                                           float epsilon,
+                                           MetaTensor* x_grad,
+                                           MetaTensor* scale_grad,
+                                           MetaTensor* bias_grad) {
+  PADDLE_ENFORCE_EQ(
+      x.dtype() == DataType::FLOAT32 || x.dtype() == DataType::FLOAT16 ||
+          x.dtype() == DataType::BFLOAT16,
+      true,
+      common::errors::InvalidArgument(
+          "The dtype of x must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]",
+          x.dtype()));
+  PADDLE_ENFORCE_EQ(
+      scale.dtype() == DataType::FLOAT32 ||
+          scale.dtype() == DataType::FLOAT16 ||
+          scale.dtype() == DataType::BFLOAT16,
+      true,
+      common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
+                                      "FLOAT16 or BFLOAT16, but got [%s]",
+                                      scale.dtype()));
+  if (x_grad && x) {
+    x_grad->share_meta(x);
+  }
+  if (scale_grad && scale) {
+    scale_grad->share_meta(scale);
+  }
+  if (bias_grad) {
+    bias_grad->share_meta(scale);
+  }
+}
+
+PADDLE_API void FastRMSNormGradInfermeta(const MetaTensor& x,
+                                         const MetaTensor& scale,
+                                         const MetaTensor& invvar,
+                                         const MetaTensor& y_grad,
+                                         float epsilon,
+                                         MetaTensor* x_grad,
+                                         MetaTensor* scale_grad) {
+  PADDLE_ENFORCE_EQ(
+      x.dtype() == DataType::FLOAT32 || x.dtype() == DataType::FLOAT16 ||
+          x.dtype() == DataType::BFLOAT16,
+      true,
+      common::errors::InvalidArgument(
+          "The dtype of x must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]",
+          x.dtype()));
+  PADDLE_ENFORCE_EQ(
+      scale.dtype() == DataType::FLOAT32 ||
+          scale.dtype() == DataType::FLOAT16 ||
+          scale.dtype() == DataType::BFLOAT16,
+      true,
+      common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
+                                      "FLOAT16 or BFLOAT16, but got [%s]",
+                                      scale.dtype()));
+  if (x_grad && x) {
+    x_grad->share_meta(x);
+  }
+  if (scale_grad && scale) {
+    scale_grad->share_meta(scale);
+  }
+}
 void IndexElementwiseGetGradInferMeta(
     const MetaTensor& x,
     const std::vector<const MetaTensor*>& index,
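
The two gradient InferMeta functions above only propagate metadata: x_grad borrows x's dims and dtype, while scale_grad and bias_grad both borrow scale's, since a layer-norm bias always has the same 1-D [hidden] shape as scale. A minimal standalone sketch of that contract (plain C++ with a hypothetical Meta struct and made-up shapes, not phi types):

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for phi::MetaTensor metadata.
    struct Meta {
      std::vector<int64_t> dims;
      std::string dtype;
    };

    int main() {
      Meta x{{8, 128, 1024}, "bfloat16"};  // example activation: [batch, seq, hidden]
      Meta scale{{1024}, "bfloat16"};      // 1-D affine parameter over the last axis

      Meta x_grad = x;          // x_grad->share_meta(x)
      Meta scale_grad = scale;  // scale_grad->share_meta(scale)
      Meta bias_grad = scale;   // bias_grad->share_meta(scale): same shape/dtype as scale

      std::cout << x_grad.dims.size() << " " << scale_grad.dims[0] << " "
                << bias_grad.dims[0] << "\n";  // prints: 3 1024 1024
      return 0;
    }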

paddle/phi/infermeta/backward.h

Lines changed: 18 additions & 0 deletions
@@ -842,4 +842,22 @@ PADDLE_API void IndexElementwiseGetGradInferMeta(
     const bool accumulate,
     const bool is_combined,
     MetaTensor* x_grad);
+
+PADDLE_API void FastLayerNormGradInfermeta(const MetaTensor& x,
+                                           const MetaTensor& scale,
+                                           const MetaTensor& mean,
+                                           const MetaTensor& invvar,
+                                           const MetaTensor& y_grad,
+                                           float epsilon,
+                                           MetaTensor* x_grad,
+                                           MetaTensor* scale_grad,
+                                           MetaTensor* bias_grad);
+
+PADDLE_API void FastRMSNormGradInfermeta(const MetaTensor& x,
+                                         const MetaTensor& scale,
+                                         const MetaTensor& invvar,
+                                         const MetaTensor& y_grad,
+                                         float epsilon,
+                                         MetaTensor* x_grad,
+                                         MetaTensor* scale_grad);
 }  // namespace phi

paddle/phi/infermeta/binary.cc

Lines changed: 50 additions & 0 deletions
@@ -1845,6 +1845,56 @@ void ExpandAsInferMeta(const MetaTensor& x,
 #undef MAX_RANK_SUPPORTED
 }
 
+void FastRMSNormInfermeta(const MetaTensor& x,
+                          const MetaTensor& scale,
+                          float epsilon,
+                          MetaTensor* y,
+                          MetaTensor* invvar) {
+  auto x_dim = x.dims();
+  auto x_ndim = x_dim.size();
+
+  auto matrix_dim = common::flatten_to_2d(x_dim, x_ndim - 1);
+
+  int64_t right = matrix_dim[1];
+  if (scale) {
+    PADDLE_ENFORCE_EQ(scale.dims().size(),
+                      1,
+                      common::errors::InvalidArgument(
+                          "The dimensions of Input(Scale) must be 1, but "
+                          "received dimensions of "
+                          "Input(Scale) is [%d]",
+                          scale.dims().size()));
+  }
+
+  PADDLE_ENFORCE_EQ(
+      scale.dims()[0],
+      right,
+      common::errors::InvalidArgument(
+          "The first dimension value of Input(Scale) must equal to be the "
+          "second dimension value of the flattened 2D matrix of Input(X), "
+          "But received the first dimension value of Input(Scale) is "
+          "[%d], the second dimension value of the flattened 2D matrix of "
+          " Input(Scale) is [%d].",
+          scale.dims()[0],
+          right));
+
+  PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
+                    true,
+                    common::errors::InvalidArgument(
+                        "'epsilon' in Op(LayerNorm) should be between"
+                        "0.0 and 0.001, But received [%s].",
+                        epsilon));
+
+  phi::DataType x_dtype = x.dtype();
+  phi::DataType scale_dtype = scale.dtype();
+  y->set_dims(x_dim);
+  y->set_dtype(scale_dtype);
+
+  auto row_shape = slice_ddim(x_dim, 0, x_dim.size() - 1);
+  invvar->set_dims({row_shape});
+  invvar->set_dtype(paddle::DataType::FLOAT32);
+}
+
 void FakeDequantizeMaxAbsInferMeta(const MetaTensor& x,
                                    const MetaTensor& scale,
                                    float max_range,
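
As a numerical reference for what the new fast_rms_norm op produces per row (a sketch matching the definition summarized earlier, not the CUDA kernel itself; the values are made up for illustration):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      const float epsilon = 1e-5f;                      // must lie in [0, 0.001] per the check above
      std::vector<float> x = {1.f, 2.f, 3.f, 4.f};      // one row, cols = 4
      std::vector<float> scale = {1.f, 1.f, 1.f, 1.f};  // 1-D scale with cols elements

      // invvar = 1 / sqrt(mean(x^2) + epsilon), saved per row in FLOAT32 for the backward pass.
      float mean_sq = 0.f;
      for (float v : x) mean_sq += v * v;
      mean_sq /= static_cast<float>(x.size());
      float invvar = 1.f / std::sqrt(mean_sq + epsilon);

      for (size_t i = 0; i < x.size(); ++i) {
        std::printf("y[%zu] = %f\n", i, scale[i] * x[i] * invvar);
      }
      return 0;
    }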

paddle/phi/infermeta/binary.h

Lines changed: 6 additions & 0 deletions
@@ -359,6 +359,12 @@ PADDLE_API void ExpandAsInferMeta(const MetaTensor& x,
                                   const std::vector<int64_t>& target_shape,
                                   MetaTensor* out);
 
+PADDLE_API void FastRMSNormInfermeta(const MetaTensor& x,
+                                     const MetaTensor& scale,
+                                     float epsilon,
+                                     MetaTensor* y,
+                                     MetaTensor* invvar);
+
 PADDLE_API void FakeDequantizeMaxAbsInferMeta(const MetaTensor& x,
                                               const MetaTensor& scale,
                                               float max_range,

paddle/phi/infermeta/ternary.cc

Lines changed: 73 additions & 0 deletions
@@ -604,6 +604,79 @@ void DpsgdInferMeta(const MetaTensor& param,
   param_out->set_dims(param_dims);
 }
 
+void FastLayerNormInfermeta(const MetaTensor& x,
+                            const MetaTensor& scale,
+                            const MetaTensor& bias,
+                            float epsilon,
+                            MetaTensor* y,
+                            MetaTensor* mean,
+                            MetaTensor* invvar) {
+  auto x_dim = x.dims();
+  auto x_ndim = x_dim.size();
+
+  auto matrix_dim = common::flatten_to_2d(x_dim, x_ndim - 1);
+
+  int64_t right = matrix_dim[1];
+  if (scale) {
+    PADDLE_ENFORCE_EQ(scale.dims().size(),
+                      1,
+                      common::errors::InvalidArgument(
+                          "The dimensions of Input(Scale) must be 1, but "
+                          "received dimensions of "
+                          "Input(Scale) is [%d]",
+                          scale.dims().size()));
+  }
+
+  PADDLE_ENFORCE_EQ(
+      scale.dims()[0],
+      right,
+      common::errors::InvalidArgument(
+          "The first dimension value of Input(Scale) must equal to be the "
+          "second dimension value of the flattened 2D matrix of Input(X), "
+          "But received the first dimension value of Input(Scale) is "
+          "[%d], the second dimension value of the flattened 2D matrix of "
+          " Input(Scale) is [%d].",
+          scale.dims()[0],
+          right));
+  if (bias) {
+    PADDLE_ENFORCE_EQ(bias.dims().size(),
+                      1,
+                      common::errors::InvalidArgument(
+                          "The dimensions of Input(Bias) must be 1, but "
+                          "received dimensions of "
+                          "Input(Bias) is [%d]",
+                          bias.dims().size()));
+  }
+  PADDLE_ENFORCE_EQ(
+      bias.dims()[0],
+      right,
+      common::errors::InvalidArgument(
+          "The first dimension value of Input(Bias) must equal to be the "
+          "second dimension value of the flattened 2D matrix of Input(X), "
+          "But received the first dimension value of Input(Bias) is "
+          "[%d], the second dimension value of the flattened 2D matrix of "
+          " Input(Bias) is [%d].",
+          bias.dims()[0],
+          right));
+
+  PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
+                    true,
+                    common::errors::InvalidArgument(
+                        "'epsilon' in Op(LayerNorm) should be between"
+                        "0.0 and 0.001, But received [%s].",
+                        epsilon));
+
+  phi::DataType x_dtype = x.dtype();
+  phi::DataType scale_dtype = scale.dtype();
+  y->set_dims(x_dim);
+  y->set_dtype(scale_dtype);
+
+  auto row_shape = slice_ddim(x_dim, 0, x_dim.size() - 1);
+  mean->set_dims({row_shape});
+  mean->set_dtype(paddle::DataType::FLOAT32);
+  invvar->set_dims({row_shape});
+  invvar->set_dtype(paddle::DataType::FLOAT32);
+}
 void FakeQuantizeRangeAbsMaxInferMeta(const MetaTensor& x,
                                       const MetaTensor& in_scale,
                                       const MetaTensor& iter,
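
A standalone sketch of the shape bookkeeping FastLayerNormInfermeta performs above: x is viewed as a [rows, right] matrix by flattening all but the last axis, y keeps x's shape, and mean/invvar get one FLOAT32 value per row (plain C++ with made-up example shapes, for illustration only):

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main() {
      std::vector<int64_t> x_dims = {8, 128, 1024};  // example: [batch, seq, hidden]
      std::vector<int64_t> scale_dims = {1024};      // must be 1-D with `right` elements

      // common::flatten_to_2d(x_dim, ndim - 1): rows = product of leading dims, right = last dim.
      int64_t right = x_dims.back();
      int64_t rows = std::accumulate(
          x_dims.begin(), x_dims.end() - 1, int64_t{1}, std::multiplies<int64_t>());
      assert(scale_dims.size() == 1 && scale_dims[0] == right);

      std::vector<int64_t> y_dims = x_dims;  // y shares x's shape (dtype follows scale)
      std::vector<int64_t> row_shape(x_dims.begin(), x_dims.end() - 1);  // slice_ddim(x_dim, 0, n-1)
      // mean and invvar use row_shape ({8, 128} here) and FLOAT32, regardless of x's dtype.

      assert(rows == 8 * 128 && right == 1024);
      assert(row_shape.size() == x_dims.size() - 1);
      (void)y_dims;
      return 0;
    }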

paddle/phi/infermeta/ternary.h

Lines changed: 8 additions & 0 deletions
@@ -179,6 +179,14 @@ PADDLE_API void FlashAttnV3InferMeta(const MetaTensor& q,
                                      MetaTensor* out,
                                      MetaTensor* softmax_lse);
 
+PADDLE_API void FastLayerNormInfermeta(const MetaTensor& x,
+                                       const MetaTensor& scale,
+                                       const MetaTensor& bias,
+                                       float epsilon,
+                                       MetaTensor* y,
+                                       MetaTensor* mean,
+                                       MetaTensor* invvar);
+
 PADDLE_API void FlashAttnV3VarlenInferMeta(const MetaTensor& q,
                                            const MetaTensor& k,
                                            const MetaTensor& v,

paddle/phi/kernels/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
@@ -76,6 +76,13 @@ if(((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 12.0))
       "legacy/gpu/int_bincount.cu"
       "legacy/gpu/fp8_gemm_blockwise_kernel.cu"
       "legacy/gpu/fp8_quant_blockwise_kernel.cu"
+      "legacy/gpu/fast_layernorm_kernel.cu"
+      "legacy/gpu/fast_layernorm_grad_kernel.cu"
+      "legacy/gpu/fast_rmsnorm_kernel.cu"
+      "legacy/gpu/fast_rmsnorm_grad_kernel.cu"
+      "legacy/gpu/ln.cu"
+      "legacy/gpu/ln_bwd_semi_cuda_kernel.cu"
+      "legacy/gpu/ln_fwd_cuda_kernel.cu"
       "fusion/gpu/fused_act_dequant_kernel.cu"
       "fusion/gpu/fused_stack_transpose_quant_kernel.cu"
       "fusion/gpu/fused_transpose_split_quant_kernel.cu"

New file — Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
+
+/* This code is copied from NVIDIA apex:
+ *     https://github.com/NVIDIA/apex
+ * with minor changes. */
+
+#include "ln.h"  // NOLINT
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LnBwdKernel(const Context &dev_ctx,
+                 const DenseTensor &x,
+                 const DenseTensor &scale,
+                 const DenseTensor &mean,
+                 const DenseTensor &invvar,
+                 const DenseTensor &y_grad,
+                 float epsilon,
+                 DenseTensor *x_grad,
+                 DenseTensor *scale_grad,
+                 DenseTensor *bias_grad) {
+  auto input_type = x.type();
+  auto weight_type = scale.type();
+  auto output_type = weight_type;
+  auto compute_type = paddle::DataType::FLOAT32;
+
+  PD_CHECK(y_grad.dtype() == output_type);
+
+  auto sizes = x.dims();
+  PD_CHECK(sizes.size() >= 2);
+  PD_CHECK(y_grad.dims() == sizes);
+
+  int64_t rows = 1;
+  for (size_t i = 0; i + 1 < sizes.size(); ++i) {
+    rows *= sizes[i];
+  }
+  auto cols = sizes[sizes.size() - 1];
+
+  auto hidden_size = scale.numel();
+
+  PD_CHECK(mean.numel() == rows);
+
+  PD_CHECK(mean.dims() == invvar.dims());
+
+  PD_CHECK(scale.numel() == cols);
+
+  dev_ctx.template Alloc<T>(x_grad);
+  dev_ctx.template Alloc<T>(scale_grad);
+  dev_ctx.template Alloc<T>(bias_grad);
+
+  auto place = x.place();
+
+  LaunchNormBwd<T, Context>(
+      dev_ctx,
+      dev_ctx.stream(),
+      place,
+      /* x_ptr */ x.data(),
+      /* scale_ptr */ scale.data(),
+      /* mean_ptr */ mean.data(),
+      /* invvar_ptr */ invvar.data(),
+      /* y_grad_ptr */ y_grad.data(),
+      /* x_grad_ptr */ x_grad ? x_grad->data() : nullptr,
+      /* scale_grad_ptr */ scale_grad ? scale_grad->data() : nullptr,
+      /* bias_grad_ptr */ bias_grad ? bias_grad->data() : nullptr,
+      weight_type,
+      input_type,
+      output_type,
+      compute_type,
+      hidden_size,
+      rows,
+      cols,
+      epsilon);
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fast_ln_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LnBwdKernel,
+                   float,
+                   double,
+                   phi::float16,
+                   phi::bfloat16) {}
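
For reference, the quantities a fused layer-norm backward of this kind produces are the standard gradients. Writing x_hat = (x - mean) * invvar and g = y_grad * scale, with the per-row averages mean_H(.) taken over the H = cols elements of a row, the usual derivation gives (a summary of the math, not quoted from the kernel):

    bias_grad  = sum over rows of y_grad
    scale_grad = sum over rows of (y_grad * x_hat)
    x_grad     = invvar * (g - mean_H(g) - x_hat * mean_H(g * x_hat))

This is what the rows/cols flattening and the per-row mean/invvar buffers above feed into.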
