 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/complex_kernel.h"
-#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"

 namespace phi {
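+// Mixed-precision Add gradient for the case where x and dout are float32 and
+// y is float16/bfloat16 (YType): dx is written in float32, while dy is
+// obtained by casting dout to YType and, when the shapes differ, reducing it
+// over the broadcast dimensions.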
+template <typename YType, typename Context>
+void MixedPrecisionAddGradKernel(const Context& dev_ctx,
+                                 const DenseTensor& x,
+                                 const DenseTensor& y,
+                                 const DenseTensor& dout,
+                                 int axis,
+                                 DenseTensor* dx,
+                                 DenseTensor* dy) {
+  using T = float;
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  using XPUYType = typename XPUTypeTrait<YType>::Type;
+
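+  // Zero-sized dout: nothing to back-propagate, so just allocate the
+  // requested outputs and zero-fill any that are non-empty.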
+  if (dout.numel() == 0) {
+    if (dx) {
+      dev_ctx.template Alloc<T>(dx);
+      if (dx->numel() > 0) {
+        int ret =
+            xpu::constant<XPUType>(dev_ctx.x_context(),
+                                   reinterpret_cast<XPUType*>(dx->data<T>()),
+                                   dx->numel(),
+                                   static_cast<XPUType>(0));
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
+      }
+    }
+    if (dy) {
+      dev_ctx.template Alloc<YType>(dy);
+      if (dy->numel() > 0) {
+        int ret = xpu::constant<XPUYType>(
+            dev_ctx.x_context(),
+            reinterpret_cast<XPUYType*>(dy->data<YType>()),
+            dy->numel(),
+            static_cast<XPUYType>(0));
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
+      }
+    }
+    return;
+  }
+
+  funcs::ElementwiseGradPreProcess(dout, dx);
+  auto* dz = &dout;
+  const DDim& dz_dims = dz->dims();
+  const T* dz_data = dz->data<T>();
+
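+  // dx keeps x's float32 dtype: copy dout when the shapes match, otherwise
+  // reduce_sum dout over the broadcast dimensions.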
+  if (dx != nullptr) {
+    T* dx_data = dev_ctx.template Alloc<T>(dx);
+    if (dx->dims() == dz_dims) {
+      if (dx_data != dz_data) {
+        int ret = xpu::copy(dev_ctx.x_context(),
+                            reinterpret_cast<const XPUType*>(dz_data),
+                            reinterpret_cast<XPUType*>(dx_data),
+                            dx->numel());
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
+      }
+    } else {
+      // For inplace strategy, dx will be stored in addr of dz, which makes
+      // the result of dy wrong.
+      if (dx->IsSharedBufferWith(*dz)) {
+        dx->clear();
+        dx->Resize(x.dims());
+        dev_ctx.template Alloc<T>(dx);
+      }
+      std::vector<int> reduce_dims =
+          funcs::GetReduceDim(dx->dims(), dz_dims, axis);
+      std::vector<int64_t> dz_vector = common::vectorize<int64_t>(dz_dims);
+
+      int ret = xpu::reduce_sum<XPUType>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(dz_data),
+          reinterpret_cast<XPUType*>(dx_data),
+          dz_vector,
+          std::vector<int64_t>(reduce_dims.begin(), reduce_dims.end()));
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
+    }
+  }
+
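+  // dy takes y's dtype: cast dout to YType and, when the shapes differ,
+  // additionally reduce_sum over the broadcast dimensions.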
+  if (dy != nullptr) {
+    YType* dy_data = dev_ctx.template Alloc<YType>(dy);
+    if (dy->dims() == dz_dims) {
+      int ret = xpu::cast<XPUType, XPUYType>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(dz_data),
+          reinterpret_cast<XPUYType*>(dy_data),
+          dout.numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "cast");
+    } else {
+      std::vector<int> reduce_dims =
+          funcs::GetReduceDim(dy->dims(), dz_dims, axis);
+      std::vector<int64_t> dz_vector = common::vectorize<int64_t>(dz_dims);
+
+      DenseTensor casted_dz;
+      casted_dz.Resize(dz_dims);
+      YType* casted_dz_data = dev_ctx.template Alloc<YType>(&casted_dz);
+
+      int ret_cast = xpu::cast<XPUType, XPUYType>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(dz_data),
+          reinterpret_cast<XPUYType*>(casted_dz_data),
+          dout.numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret_cast, "cast");
+
+      int ret_reduce = xpu::reduce_sum<XPUYType>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUYType*>(casted_dz_data),
+          reinterpret_cast<XPUYType*>(dy_data),
+          dz_vector,
+          std::vector<int64_t>(reduce_dims.begin(), reduce_dims.end()));
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret_reduce, "reduce_sum");
+    }
+  }
+}
+
 template <typename T, typename Context>
 void AddGradKernel(const Context& dev_ctx,
                    const DenseTensor& x,
@@ -36,30 +146,50 @@ void AddGradKernel(const Context& dev_ctx,
                    int axis,
                    DenseTensor* dx,
                    DenseTensor* dy) {
+  // special case for "float32 + bfloat16", or "float32 + float16"
+  if (x.dtype() == DataType::FLOAT32) {
+    if (y.dtype() == DataType::FLOAT16) {
+      MixedPrecisionAddGradKernel<phi::float16>(
+          dev_ctx, x, y, dout, axis, dx, dy);
+      return;
+    }
+    if (y.dtype() == DataType::BFLOAT16) {
+      MixedPrecisionAddGradKernel<phi::bfloat16>(
+          dev_ctx, x, y, dout, axis, dx, dy);
+      return;
+    }
+  }
+
+  using XPUType = typename XPUTypeTrait<T>::Type;
   if (dout.numel() == 0) {
     if (dx) {
-      if (dx->numel() == 0) {
-        dev_ctx.template Alloc<T>(dx);
-      } else {
-        phi::Full<T, Context>(
-            dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+      dev_ctx.template Alloc<T>(dx);
+      if (dx->numel() > 0) {
+        int ret =
+            xpu::constant<XPUType>(dev_ctx.x_context(),
+                                   reinterpret_cast<XPUType*>(dx->data<T>()),
+                                   dx->numel(),
+                                   static_cast<XPUType>(0));
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
       }
     }
     if (dy) {
-      if (dy->numel() == 0) {
-        dev_ctx.template Alloc<T>(dy);
-      } else {
-        phi::Full<T, Context>(
-            dev_ctx, phi::IntArray(common::vectorize(dy->dims())), 0, dy);
+      dev_ctx.template Alloc<T>(dy);
+      if (dy->numel() > 0) {
+        int ret =
+            xpu::constant<XPUType>(dev_ctx.x_context(),
+                                   reinterpret_cast<XPUType*>(dy->data<T>()),
+                                   dy->numel(),
+                                   static_cast<XPUType>(0));
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
       }
     }
     return;
   }
-  using XPUType = typename XPUTypeTrait<T>::Type;
+
   funcs::ElementwiseGradPreProcess(dout, dx);
   auto* dz = &dout;
   const DDim& dz_dims = dz->dims();
-
   const T* dz_data = dz->data<T>();

   if (dx != nullptr) {
@@ -68,7 +198,7 @@ void AddGradKernel(const Context& dev_ctx,
       if (dx_data != dz_data) {
         int ret = xpu::copy(dev_ctx.x_context(),
                             reinterpret_cast<const XPUType*>(dz_data),
-                            reinterpret_cast<XPUType*>(dx->data<T>()),
+                            reinterpret_cast<XPUType*>(dx_data),
                             dx->numel());
         PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
       }
@@ -87,7 +217,7 @@ void AddGradKernel(const Context& dev_ctx,
       int ret = xpu::reduce_sum<XPUType>(
           dev_ctx.x_context(),
           reinterpret_cast<const XPUType*>(dz_data),
-          reinterpret_cast<XPUType*>(dx->data<T>()),
+          reinterpret_cast<XPUType*>(dx_data),
           dz_vector,
           std::vector<int64_t>(reduce_dims.begin(), reduce_dims.end()));
       PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
@@ -100,7 +230,7 @@ void AddGradKernel(const Context& dev_ctx,
       if (dy_data != dz_data) {
         int ret = xpu::copy(dev_ctx.x_context(),
                             reinterpret_cast<const XPUType*>(dz_data),
-                            reinterpret_cast<XPUType*>(dy->data<T>()),
+                            reinterpret_cast<XPUType*>(dy_data),
                             dy->numel());
         PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
       }
@@ -118,6 +248,7 @@ void AddGradKernel(const Context& dev_ctx,
     }
   }
 }
+
 #ifdef PADDLE_WITH_XPU_FFT
 template <>
 void AddGradKernel<phi::complex64, XPUContext>(const XPUContext& dev_ctx,