Skip to content

Commit 8483578

Browse files
committed
fix shape bug
1 parent 58b4c9a commit 8483578

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

paddle/operators/reduce_op.h

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include "glog/logging.h"
1718
#include "paddle/framework/eigen.h"
1819
#include "paddle/framework/op_registry.h"
1920

@@ -26,6 +27,10 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
2627
typename IndexType = Eigen::DenseIndex>
2728
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
2829

30+
template <typename T, int MajorType = Eigen::RowMajor,
31+
typename IndexType = Eigen::DenseIndex>
32+
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
33+
2934
struct SumFunctor {
3035
template <typename Place, typename X, typename Y, typename Dim>
3136
void operator()(const Place& place, X& x, Y& y, const Dim& dim) {
@@ -133,10 +138,21 @@ class ReduceKernel : public framework::OpKernel<T> {
133138
dims_vector.erase(dims_vector.begin() + dim);
134139
dims = framework::make_ddim(dims_vector);
135140
}
136-
auto out = EigenTensor < T, D == 1 ? 1 : (D - 1) > ::From(*output, dims);
141+
137142
auto& place = context.GetEigenDevice<Place>();
138143
Functor functor;
139-
functor(place, x, out, reduce_dim);
144+
145+
if (D == 1) {
146+
auto out = EigenScalar<T>::From(*output);
147+
// auto out = EigenTensor<T, 1>::From(*output, dims);
148+
VLOG(0) << "x dims : " << x.rank() << " out dims : " << out.rank();
149+
functor(place, x, out, reduce_dim);
150+
} else {
151+
auto out = EigenTensor<T, (D - 1)>::From(*output, dims);
152+
// VLOG(0) << "x dims : "<< x.dimensions().size() << " out dims : "
153+
// << out.dimensions().size();
154+
functor(place, x, out, reduce_dim);
155+
}
140156
}
141157
};
142158

@@ -186,13 +202,13 @@ class ReduceGradKernel : public framework::OpKernel<T> {
186202
auto x_reduce = EigenTensor<T, D>::From(*input1, dims);
187203
auto x_reduce_grad = EigenTensor<T, D>::From(*input2, dims);
188204

189-
Eigen::array<int, D> braodcast_dim;
190-
for (size_t i = 0; i < D; ++i) braodcast_dim[i] = 1;
191-
braodcast_dim[dim] = input0->dims()[dim];
205+
Eigen::array<int, D> broadcast_dim;
206+
for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;
207+
broadcast_dim[dim] = input0->dims()[dim];
192208
auto& place = context.GetEigenDevice<Place>();
193209
Functor functor;
194-
functor(place, x, x_reduce, x_grad, x_reduce_grad, braodcast_dim,
195-
braodcast_dim[dim]);
210+
functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
211+
broadcast_dim[dim]);
196212
}
197213
};
198214

0 commit comments

Comments
 (0)