Skip to content

Commit b86f58b

Browse files
authored
[Accuracy diff No.148] Fix accuracy diff for paddle.Tensor.tile API (#73454)
* fix tile_grad accuracy * apply review feedback: clean up useless code
1 parent 02b3495 commit b86f58b

File tree

1 file changed

+43
-13
lines changed

1 file changed

+43
-13
lines changed

paddle/phi/kernels/impl/tile_grad_kernel_impl.h

Lines changed: 43 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
#include <vector>
1818

1919
#include "paddle/phi/core/tensor_utils.h"
20+
#include "paddle/phi/kernels/cast_kernel.h"
2021
#include "paddle/phi/kernels/funcs/eigen/common.h"
2122
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
2223
#include "paddle/phi/kernels/tile_grad_kernel.h"
@@ -33,20 +34,49 @@ void TileBackward(const Context& dev_ctx,
3334
size_t reduce_size = reduce_dims_vec.size();
3435
dev_ctx.template Alloc<T>(x_grad);
3536

36-
auto eigen_x_grad = EigenVector<T>::Flatten(*x_grad);
37-
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
38-
for (size_t i = 0; i < reshape_size; ++i) {
39-
reshape_dims[i] = reshape_dims_vec[i];
40-
}
41-
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
42-
for (size_t i = 0; i < reduce_size; ++i) {
43-
reduce_dims[i] = reduce_dims_vec[i];
44-
}
37+
if constexpr (std::is_same_v<T, dtype::float16> ||
38+
std::is_same_v<T, dtype::bfloat16>) {
39+
const DenseTensor out_grad_fp32 =
40+
phi::Cast<T, Context>(dev_ctx, out_grad, DataType::FLOAT32);
41+
DenseTensor x_grad_fp32;
42+
x_grad_fp32.Resize(x_grad->dims());
43+
dev_ctx.template Alloc<float>(&x_grad_fp32);
44+
auto eigen_x_grad = EigenVector<float>::Flatten(x_grad_fp32);
45+
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
46+
for (size_t i = 0; i < reshape_size; ++i) {
47+
reshape_dims[i] = reshape_dims_vec[i];
48+
}
49+
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
50+
for (size_t i = 0; i < reduce_size; ++i) {
51+
reduce_dims[i] = reduce_dims_vec[i];
52+
}
53+
const auto eigen_out_grad_fp32 = EigenVector<float>::Flatten(out_grad_fp32);
54+
auto& place = *dev_ctx.eigen_device();
55+
funcs::EigenBroadcastGrad<std::decay_t<decltype(place)>, float, Dims>::Eval(
56+
place, eigen_x_grad, eigen_out_grad_fp32, reduce_dims, reshape_dims);
57+
if constexpr (std::is_same_v<T, dtype::float16>) {
58+
phi::CastKernel<float, Context>(
59+
dev_ctx, x_grad_fp32, DataType::FLOAT16, x_grad);
60+
} else {
61+
phi::CastKernel<float, Context>(
62+
dev_ctx, x_grad_fp32, DataType::BFLOAT16, x_grad);
63+
}
64+
} else {
65+
auto eigen_x_grad = EigenVector<T>::Flatten(*x_grad);
66+
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
67+
for (size_t i = 0; i < reshape_size; ++i) {
68+
reshape_dims[i] = reshape_dims_vec[i];
69+
}
70+
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
71+
for (size_t i = 0; i < reduce_size; ++i) {
72+
reduce_dims[i] = reduce_dims_vec[i];
73+
}
4574

46-
auto eigen_out_grad = EigenVector<T>::Flatten(out_grad);
47-
auto& place = *dev_ctx.eigen_device();
48-
funcs::EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
49-
place, eigen_x_grad, eigen_out_grad, reduce_dims, reshape_dims);
75+
auto eigen_out_grad = EigenVector<T>::Flatten(out_grad);
76+
auto& place = *dev_ctx.eigen_device();
77+
funcs::EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
78+
place, eigen_x_grad, eigen_out_grad, reduce_dims, reshape_dims);
79+
}
5080
}
5181

5282
template <typename T, typename Context>

0 commit comments

Comments
 (0)