17
17
#include < vector>
18
18
19
19
#include " paddle/phi/core/tensor_utils.h"
20
+ #include " paddle/phi/kernels/cast_kernel.h"
20
21
#include " paddle/phi/kernels/funcs/eigen/common.h"
21
22
#include " paddle/phi/kernels/funcs/eigen/eigen_function.h"
22
23
#include " paddle/phi/kernels/tile_grad_kernel.h"
@@ -33,20 +34,49 @@ void TileBackward(const Context& dev_ctx,
33
34
size_t reduce_size = reduce_dims_vec.size ();
34
35
dev_ctx.template Alloc <T>(x_grad);
35
36
36
- auto eigen_x_grad = EigenVector<T>::Flatten (*x_grad);
37
- Eigen::DSizes<Eigen::DenseIndex, Dims * 2 > reshape_dims;
38
- for (size_t i = 0 ; i < reshape_size; ++i) {
39
- reshape_dims[i] = reshape_dims_vec[i];
40
- }
41
- Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
42
- for (size_t i = 0 ; i < reduce_size; ++i) {
43
- reduce_dims[i] = reduce_dims_vec[i];
44
- }
37
+ if constexpr (std::is_same_v<T, dtype::float16> ||
38
+ std::is_same_v<T, dtype::bfloat16>) {
39
+ const DenseTensor out_grad_fp32 =
40
+ phi::Cast<T, Context>(dev_ctx, out_grad, DataType::FLOAT32);
41
+ DenseTensor x_grad_fp32;
42
+ x_grad_fp32.Resize (x_grad->dims ());
43
+ dev_ctx.template Alloc <float >(&x_grad_fp32);
44
+ auto eigen_x_grad = EigenVector<float >::Flatten (x_grad_fp32);
45
+ Eigen::DSizes<Eigen::DenseIndex, Dims * 2 > reshape_dims;
46
+ for (size_t i = 0 ; i < reshape_size; ++i) {
47
+ reshape_dims[i] = reshape_dims_vec[i];
48
+ }
49
+ Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
50
+ for (size_t i = 0 ; i < reduce_size; ++i) {
51
+ reduce_dims[i] = reduce_dims_vec[i];
52
+ }
53
+ const auto eigen_out_grad_fp32 = EigenVector<float >::Flatten (out_grad_fp32);
54
+ auto & place = *dev_ctx.eigen_device ();
55
+ funcs::EigenBroadcastGrad<std::decay_t <decltype (place)>, float , Dims>::Eval (
56
+ place, eigen_x_grad, eigen_out_grad_fp32, reduce_dims, reshape_dims);
57
+ if constexpr (std::is_same_v<T, dtype::float16>) {
58
+ phi::CastKernel<float , Context>(
59
+ dev_ctx, x_grad_fp32, DataType::FLOAT16, x_grad);
60
+ } else {
61
+ phi::CastKernel<float , Context>(
62
+ dev_ctx, x_grad_fp32, DataType::BFLOAT16, x_grad);
63
+ }
64
+ } else {
65
+ auto eigen_x_grad = EigenVector<T>::Flatten (*x_grad);
66
+ Eigen::DSizes<Eigen::DenseIndex, Dims * 2 > reshape_dims;
67
+ for (size_t i = 0 ; i < reshape_size; ++i) {
68
+ reshape_dims[i] = reshape_dims_vec[i];
69
+ }
70
+ Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
71
+ for (size_t i = 0 ; i < reduce_size; ++i) {
72
+ reduce_dims[i] = reduce_dims_vec[i];
73
+ }
45
74
46
- auto eigen_out_grad = EigenVector<T>::Flatten (out_grad);
47
- auto & place = *dev_ctx.eigen_device ();
48
- funcs::EigenBroadcastGrad<std::decay_t <decltype (place)>, T, Dims>::Eval (
49
- place, eigen_x_grad, eigen_out_grad, reduce_dims, reshape_dims);
75
+ auto eigen_out_grad = EigenVector<T>::Flatten (out_grad);
76
+ auto & place = *dev_ctx.eigen_device ();
77
+ funcs::EigenBroadcastGrad<std::decay_t <decltype (place)>, T, Dims>::Eval (
78
+ place, eigen_x_grad, eigen_out_grad, reduce_dims, reshape_dims);
79
+ }
50
80
}
51
81
52
82
template <typename T, typename Context>
0 commit comments