Skip to content

Commit cf07f3e

Browse files
authored
Merge pull request #5565 from dzhwinter/fix/reduce_op
Fix/reduce op
2 parents 2d7ac80 + 60232d8 commit cf07f3e

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

paddle/operators/reduce_op.h

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include "glog/logging.h"
1718
#include "paddle/framework/eigen.h"
1819
#include "paddle/framework/op_registry.h"
1920

@@ -26,6 +27,10 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
2627
typename IndexType = Eigen::DenseIndex>
2728
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
2829

30+
template <typename T, int MajorType = Eigen::RowMajor,
31+
typename IndexType = Eigen::DenseIndex>
32+
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
33+
2934
struct SumFunctor {
3035
template <typename Place, typename X, typename Y, typename Dim>
3136
void operator()(const Place& place, X& x, Y& y, const Dim& dim) {
@@ -133,10 +138,17 @@ class ReduceKernel : public framework::OpKernel<T> {
133138
dims_vector.erase(dims_vector.begin() + dim);
134139
dims = framework::make_ddim(dims_vector);
135140
}
136-
auto out = EigenTensor < T, D == 1 ? 1 : (D - 1) > ::From(*output, dims);
141+
137142
auto& place = context.GetEigenDevice<Place>();
138143
Functor functor;
139-
functor(place, x, out, reduce_dim);
144+
145+
if (D == 1) {
146+
auto out = EigenScalar<T>::From(*output);
147+
functor(place, x, out, reduce_dim);
148+
} else {
149+
auto out = EigenTensor<T, (D - 1)>::From(*output, dims);
150+
functor(place, x, out, reduce_dim);
151+
}
140152
}
141153
};
142154

@@ -186,13 +198,13 @@ class ReduceGradKernel : public framework::OpKernel<T> {
186198
auto x_reduce = EigenTensor<T, D>::From(*input1, dims);
187199
auto x_reduce_grad = EigenTensor<T, D>::From(*input2, dims);
188200

189-
Eigen::array<int, D> braodcast_dim;
190-
for (size_t i = 0; i < D; ++i) braodcast_dim[i] = 1;
191-
braodcast_dim[dim] = input0->dims()[dim];
201+
Eigen::array<int, D> broadcast_dim;
202+
for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;
203+
broadcast_dim[dim] = input0->dims()[dim];
192204
auto& place = context.GetEigenDevice<Place>();
193205
Functor functor;
194-
functor(place, x, x_reduce, x_grad, x_reduce_grad, braodcast_dim,
195-
braodcast_dim[dim]);
206+
functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
207+
broadcast_dim[dim]);
196208
}
197209
};
198210

0 commit comments

Comments
 (0)