14
14
15
15
#pragma once
16
16
17
+ #include " glog/logging.h"
17
18
#include " paddle/framework/eigen.h"
18
19
#include " paddle/framework/op_registry.h"
19
20
@@ -26,6 +27,10 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
26
27
typename IndexType = Eigen::DenseIndex>
27
28
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
28
29
30
+ template <typename T, int MajorType = Eigen::RowMajor,
31
+ typename IndexType = Eigen::DenseIndex>
32
+ using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
33
+
29
34
struct SumFunctor {
30
35
template <typename Place, typename X, typename Y, typename Dim>
31
36
void operator ()(const Place& place, X& x, Y& y, const Dim& dim) {
@@ -133,10 +138,17 @@ class ReduceKernel : public framework::OpKernel<T> {
133
138
dims_vector.erase (dims_vector.begin () + dim);
134
139
dims = framework::make_ddim (dims_vector);
135
140
}
136
- auto out = EigenTensor < T, D == 1 ? 1 : (D - 1 ) > :: From (*output, dims);
141
+
137
142
auto & place = context.GetEigenDevice <Place>();
138
143
Functor functor;
139
- functor (place, x, out, reduce_dim);
144
+
145
+ if (D == 1 ) {
146
+ auto out = EigenScalar<T>::From (*output);
147
+ functor (place, x, out, reduce_dim);
148
+ } else {
149
+ auto out = EigenTensor<T, (D - 1 )>::From (*output, dims);
150
+ functor (place, x, out, reduce_dim);
151
+ }
140
152
}
141
153
};
142
154
@@ -186,13 +198,13 @@ class ReduceGradKernel : public framework::OpKernel<T> {
186
198
auto x_reduce = EigenTensor<T, D>::From (*input1, dims);
187
199
auto x_reduce_grad = EigenTensor<T, D>::From (*input2, dims);
188
200
189
- Eigen::array<int , D> braodcast_dim ;
190
- for (size_t i = 0 ; i < D; ++i) braodcast_dim [i] = 1 ;
191
- braodcast_dim [dim] = input0->dims ()[dim];
201
+ Eigen::array<int , D> broadcast_dim ;
202
+ for (size_t i = 0 ; i < D; ++i) broadcast_dim [i] = 1 ;
203
+ broadcast_dim [dim] = input0->dims ()[dim];
192
204
auto & place = context.GetEigenDevice <Place>();
193
205
Functor functor;
194
- functor (place, x, x_reduce, x_grad, x_reduce_grad, braodcast_dim ,
195
- braodcast_dim [dim]);
206
+ functor (place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim ,
207
+ broadcast_dim [dim]);
196
208
}
197
209
};
198
210
0 commit comments