
Commit b5efb98

Fix MultiPrecisionAdd 0size (PaddlePaddle#76512)
* fix MultiPrecisionAdd 0size
* add 0size test
* fix add grad mixed precision in xpu
* add head
* fix
* fix
* add MixedPrecisionAddGradKernel
* refine
* fix
1 parent 43f16a6 commit b5efb98
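
For context, a minimal Python sketch (not part of the diff) of the user-visible scenario this commit fixes, mirroring the new test added below; it assumes the current device registers the mixed-precision add path and supports float16:

    import paddle

    # A [0, 3] tensor has zero elements but a well-defined shape and dtype.
    x = paddle.randn([0, 3], dtype=paddle.float32)
    y = paddle.randn([0, 3], dtype=paddle.float16)  # mixed-precision pair
    x.stop_gradient = False
    y.stop_gradient = False

    out = paddle.add(x, y)  # forward must allocate an empty float32 output
    out.backward()          # backward must produce zero-size grads with the input dtypes

    assert out.shape == [0, 3]
    assert out.dtype == paddle.float32
    assert x.grad.dtype == paddle.float32
    assert y.grad.dtype == paddle.float16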

File tree

4 files changed: +190 -24 lines changed


paddle/phi/kernels/kps/elementwise_kernel.cu

Lines changed: 8 additions & 4 deletions
@@ -95,17 +95,21 @@ void AddKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& y,
                DenseTensor* out) {
-  if (x.numel() == 0 || y.numel() == 0) {
-    dev_ctx.template Alloc<T>(out);
-    return;
-  }
 #ifdef PADDLE_WITH_CUDA
   if (x.dtype() == DataType::FLOAT32 &&
       (y.dtype() == DataType::FLOAT16 || y.dtype() == DataType::BFLOAT16)) {
+    if (x.numel() == 0 || y.numel() == 0) {
+      dev_ctx.template Alloc<float>(out);
+      return;
+    }
     MultiPrecisionAddKernelImpl<float, Context>(dev_ctx, x, y, out);
     return;
   }
 #endif
+  if (x.numel() == 0 || y.numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   phi::AddRawKernel<T, Context>(dev_ctx, x, y, -1, out);
 }
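
The reason the early return is duplicated rather than kept at the top of the function: inside the mixed-precision branch the empty output has to be allocated as float32 (Alloc<float>), while the generic path after #endif keeps the kernel's own T. A small sketch (not part of the diff) of the observable difference, assuming a CUDA device with float16 support:

    import paddle

    empty_f32 = paddle.zeros([0, 3], dtype=paddle.float32)
    empty_f16 = paddle.zeros([0, 3], dtype=paddle.float16)

    # Mixed-precision branch: float32 + float16 with zero elements -> float32 output.
    assert paddle.add(empty_f32, empty_f16).dtype == paddle.float32

    # Generic branch: same-dtype inputs keep their dtype, even when empty.
    assert paddle.add(empty_f16, empty_f16).dtype == paddle.float16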

paddle/phi/kernels/xpu/elementwise_add_grad_kernel.cc

Lines changed: 147 additions & 16 deletions
@@ -24,10 +24,120 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/complex_kernel.h"
-#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 
 namespace phi {
+template <typename YType, typename Context>
+void MixedPrecisionAddGradKernel(const Context& dev_ctx,
+                                 const DenseTensor& x,
+                                 const DenseTensor& y,
+                                 const DenseTensor& dout,
+                                 int axis,
+                                 DenseTensor* dx,
+                                 DenseTensor* dy) {
+  using T = float;
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  using XPUYType = typename XPUTypeTrait<YType>::Type;
+
+  if (dout.numel() == 0) {
+    if (dx) {
+      dev_ctx.template Alloc<T>(dx);
+      if (dx->numel() > 0) {
+        int ret =
+            xpu::constant<XPUType>(dev_ctx.x_context(),
+                                   reinterpret_cast<XPUType*>(dx->data<T>()),
+                                   dx->numel(),
+                                   static_cast<XPUType>(0));
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
+      }
+    }
+    if (dy) {
+      dev_ctx.template Alloc<YType>(dy);
+      if (dy->numel() > 0) {
+        int ret = xpu::constant<XPUYType>(
+            dev_ctx.x_context(),
+            reinterpret_cast<XPUYType*>(dy->data<YType>()),
+            dy->numel(),
+            static_cast<XPUYType>(0));
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
+      }
+    }
+    return;
+  }
+
+  funcs::ElementwiseGradPreProcess(dout, dx);
+  auto* dz = &dout;
+  const DDim& dz_dims = dz->dims();
+  const T* dz_data = dz->data<T>();
+
+  if (dx != nullptr) {
+    T* dx_data = dev_ctx.template Alloc<T>(dx);
+    if (dx->dims() == dz_dims) {
+      if (dx_data != dz_data) {
+        int ret = xpu::copy(dev_ctx.x_context(),
+                            reinterpret_cast<const XPUType*>(dz_data),
+                            reinterpret_cast<XPUType*>(dx_data),
+                            dx->numel());
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
+      }
+    } else {
+      // For inplace strategy, dx will be stored in addr of dz, which makes
+      // the result of dy wrong.
+      if (dx->IsSharedBufferWith(*dz)) {
+        dx->clear();
+        dx->Resize(x.dims());
+        dev_ctx.template Alloc<T>(dx);
+      }
+      std::vector<int> reduce_dims =
+          funcs::GetReduceDim(dx->dims(), dz_dims, axis);
+      std::vector<int64_t> dz_vector = common::vectorize<int64_t>(dz_dims);
+
+      int ret = xpu::reduce_sum<XPUType>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(dz_data),
+          reinterpret_cast<XPUType*>(dx_data),
+          dz_vector,
+          std::vector<int64_t>(reduce_dims.begin(), reduce_dims.end()));
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
+    }
+  }
+
+  if (dy != nullptr) {
+    YType* dy_data = dev_ctx.template Alloc<YType>(dy);
+    if (dy->dims() == dz_dims) {
+      int ret = xpu::cast<XPUType, XPUYType>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(dz_data),
+          reinterpret_cast<XPUYType*>(dy_data),
+          dout.numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "cast");
+    } else {
+      std::vector<int> reduce_dims =
+          funcs::GetReduceDim(dy->dims(), dz_dims, axis);
+      std::vector<int64_t> dz_vector = common::vectorize<int64_t>(dz_dims);
+
+      DenseTensor casted_dz;
+      casted_dz.Resize(dz_dims);
+      YType* casted_dz_data = dev_ctx.template Alloc<YType>(&casted_dz);
+
+      int ret_cast = xpu::cast<XPUType, XPUYType>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(dz_data),
+          reinterpret_cast<XPUYType*>(casted_dz_data),
+          dout.numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret_cast, "cast");
+
+      int ret_reduce = xpu::reduce_sum<XPUYType>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUYType*>(casted_dz_data),
+          reinterpret_cast<XPUYType*>(dy_data),
+          dz_vector,
+          std::vector<int64_t>(reduce_dims.begin(), reduce_dims.end()));
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret_reduce, "reduce_sum");
+    }
+  }
+}
+
 template <typename T, typename Context>
 void AddGradKernel(const Context& dev_ctx,
                    const DenseTensor& x,

@@ -36,30 +146,50 @@ void AddGradKernel(const Context& dev_ctx,
                    int axis,
                    DenseTensor* dx,
                    DenseTensor* dy) {
+  // special case for "float32 + bfloat16", or "float32 + float16"
+  if (x.dtype() == DataType::FLOAT32) {
+    if (y.dtype() == DataType::FLOAT16) {
+      MixedPrecisionAddGradKernel<phi::float16>(
+          dev_ctx, x, y, dout, axis, dx, dy);
+      return;
+    }
+    if (y.dtype() == DataType::BFLOAT16) {
+      MixedPrecisionAddGradKernel<phi::bfloat16>(
+          dev_ctx, x, y, dout, axis, dx, dy);
+      return;
+    }
+  }
+
+  using XPUType = typename XPUTypeTrait<T>::Type;
   if (dout.numel() == 0) {
     if (dx) {
-      if (dx->numel() == 0) {
-        dev_ctx.template Alloc<T>(dx);
-      } else {
-        phi::Full<T, Context>(
-            dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+      dev_ctx.template Alloc<T>(dx);
+      if (dx->numel() > 0) {
+        int ret =
+            xpu::constant<XPUType>(dev_ctx.x_context(),
+                                   reinterpret_cast<XPUType*>(dx->data<T>()),
+                                   dx->numel(),
+                                   static_cast<XPUType>(0));
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
       }
     }
     if (dy) {
-      if (dy->numel() == 0) {
-        dev_ctx.template Alloc<T>(dy);
-      } else {
-        phi::Full<T, Context>(
-            dev_ctx, phi::IntArray(common::vectorize(dy->dims())), 0, dy);
+      dev_ctx.template Alloc<T>(dy);
+      if (dy->numel() > 0) {
+        int ret =
+            xpu::constant<XPUType>(dev_ctx.x_context(),
+                                   reinterpret_cast<XPUType*>(dy->data<T>()),
+                                   dy->numel(),
+                                   static_cast<XPUType>(0));
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
       }
     }
     return;
   }
-  using XPUType = typename XPUTypeTrait<T>::Type;
+
   funcs::ElementwiseGradPreProcess(dout, dx);
   auto* dz = &dout;
   const DDim& dz_dims = dz->dims();
-
   const T* dz_data = dz->data<T>();
 
   if (dx != nullptr) {

@@ -68,7 +198,7 @@ void AddGradKernel(const Context& dev_ctx,
       if (dx_data != dz_data) {
         int ret = xpu::copy(dev_ctx.x_context(),
                             reinterpret_cast<const XPUType*>(dz_data),
-                            reinterpret_cast<XPUType*>(dx->data<T>()),
+                            reinterpret_cast<XPUType*>(dx_data),
                             dx->numel());
         PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
       }

@@ -87,7 +217,7 @@ void AddGradKernel(const Context& dev_ctx,
       int ret = xpu::reduce_sum<XPUType>(
          dev_ctx.x_context(),
          reinterpret_cast<const XPUType*>(dz_data),
-          reinterpret_cast<XPUType*>(dx->data<T>()),
+          reinterpret_cast<XPUType*>(dx_data),
          dz_vector,
          std::vector<int64_t>(reduce_dims.begin(), reduce_dims.end()));
       PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");

@@ -100,7 +230,7 @@ void AddGradKernel(const Context& dev_ctx,
       if (dy_data != dz_data) {
         int ret = xpu::copy(dev_ctx.x_context(),
                             reinterpret_cast<const XPUType*>(dz_data),
-                            reinterpret_cast<XPUType*>(dy->data<T>()),
+                            reinterpret_cast<XPUType*>(dy_data),
                             dy->numel());
         PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
       }

@@ -118,6 +248,7 @@ void AddGradKernel(const Context& dev_ctx,
     }
   }
 }
+
 #ifdef PADDLE_WITH_XPU_FFT
 template <>
 void AddGradKernel<phi::complex64, XPUContext>(const XPUContext& dev_ctx,
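
To illustrate what the new MixedPrecisionAddGradKernel computes: dx stays float32, while dy is obtained by casting dout to y's dtype and, when y was broadcast, reduce-summing over the broadcast dimensions. A hedged Python sketch (not part of the diff; shapes are illustrative and it assumes the device in use supports the float32 + float16 add path):

    import paddle

    x = paddle.randn([2, 3], dtype=paddle.float32)
    y = paddle.randn([3], dtype=paddle.float16)  # broadcast along dim 0
    x.stop_gradient = False
    y.stop_gradient = False

    out = paddle.add(x, y)  # float32 result, shape [2, 3]
    out.sum().backward()

    # dx keeps x's dtype and shape.
    assert x.grad.dtype == paddle.float32 and x.grad.shape == [2, 3]
    # dout is cast to float16 and reduced over the broadcast dim to form dy.
    assert y.grad.dtype == paddle.float16 and y.grad.shape == [3]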

paddle/phi/kernels/xpu/elementwise_add_kernel.cc

Lines changed: 8 additions & 4 deletions
@@ -35,14 +35,14 @@ void AddKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& y,
                DenseTensor* out) {
-  if (out->numel() == 0) {
-    dev_ctx.template Alloc<T>(out);
-    return;
-  }
   if (x.dtype() == phi::DataType::FLOAT32 &&
       (y.dtype() == phi::DataType::BFLOAT16 ||
        y.dtype() == phi::DataType::FLOAT16)) {
     // special case for "float32 + bfloat16", or "float32 + float16"
+    if (out->numel() == 0) {
+      dev_ctx.template Alloc<float>(out);
+      return;
+    }
     auto dev_version =
         phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId());
     if (dev_version >= phi::backends::xpu::XPUVersion::XPU3 &&

@@ -82,6 +82,10 @@ void AddKernel(const Context& dev_ctx,
       XPUElementwise<Type, XPUType>(dev_ctx, x, casted_y, -1, out, f);
     }
   } else {
+    if (out->numel() == 0) {
+      dev_ctx.template Alloc<T>(out);
+      return;
+    }
     using XPUType = typename XPUTypeTrait<T>::Type;
 
     auto f = [](xpu::Context* xpu_ctx,
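
This kernel keys its zero-size early return off out->numel(), which also covers broadcasts whose result is empty because one operand has a zero dimension. A hypothetical sketch (not part of the diff; it assumes an XPU or other device that registers the mixed-precision path and supports bfloat16):

    import paddle

    x = paddle.randn([0, 3], dtype=paddle.float32)  # zero rows
    y = paddle.randn([3], dtype=paddle.bfloat16)    # non-empty, broadcast along dim 0

    out = paddle.add(x, y)  # broadcast result [0, 3] has zero elements
    assert out.shape == [0, 3]
    assert out.dtype == paddle.float32  # the empty output is still allocated as float32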

test/legacy_test/test_add_op.py

Lines changed: 27 additions & 0 deletions
@@ -20,6 +20,33 @@
 from paddle.base import core
 
 
+class TestPaddleAddZeroSize(unittest.TestCase):
+    def setUp(self):
+        self.place = get_device_place()
+        self.shape = [0, 3]
+        self.dtype_pairs = [(paddle.float32, paddle.float32)]
+        if core.is_float16_supported(self.place):
+            self.dtype_pairs.append((paddle.float32, paddle.float16))
+        if core.is_bfloat16_supported(self.place):
+            self.dtype_pairs.append((paddle.float32, paddle.bfloat16))
+
+    def test_0size(self):
+        for x_dtype, y_dtype in self.dtype_pairs:
+            with self.subTest(msg=f"{x_dtype} + {y_dtype}"):
+                x = paddle.randn(self.shape, dtype=x_dtype)
+                y = paddle.randn(self.shape, dtype=y_dtype)
+                x.stop_gradient = False
+                y.stop_gradient = False
+
+                out = paddle.add(x, y)
+                out.backward()
+
+                self.assertEqual(out.shape, self.shape)
+                self.assertEqual(out.dtype, x_dtype)
+                self.assertEqual(x.grad.dtype, x_dtype)
+                self.assertEqual(y.grad.dtype, y_dtype)
+
+
 class TestPaddleAddBackward(unittest.TestCase):
     def setUp(self):
         self.place = get_device_place()
