Skip to content

Commit 8655904

Browse files
Enhance reduce op (#10708)
* Enhance reduce op for multi dims. * Uncomment some unit tests. * Uncomment unit tests. * Remove unused code. * Fix infershape and python wrapper. * Add more examples. * Fix l2_normalize. * Fix normalization_wrapper. * Polish code. 1. Rename unit test function. 2. Rename const variable.
1 parent 051a4b3 commit 8655904

File tree

4 files changed

+233
-85
lines changed

4 files changed

+233
-85
lines changed

paddle/fluid/operators/reduce_op.cc

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/reduce_op.h"
1616

17+
#include <algorithm>
1718
#include <string>
1819
#include <vector>
1920

@@ -34,11 +35,14 @@ class ReduceOp : public framework::OperatorWithKernel {
3435
auto x_dims = ctx->GetInputDim("X");
3536
auto x_rank = x_dims.size();
3637
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
37-
int dim = ctx->Attrs().Get<int>("dim");
38-
if (dim < 0) dim = x_rank + dim;
39-
PADDLE_ENFORCE_LT(
40-
dim, x_rank,
41-
"The dim should be in the range [-rank(input), rank(input)).");
38+
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
39+
for (size_t i = 0; i < dims.size(); ++i) {
40+
if (dims[i] < 0) dims[i] = x_rank + dims[i];
41+
PADDLE_ENFORCE_LT(
42+
dims[i], x_rank,
43+
"The dim should be in the range [-rank(input), rank(input)).");
44+
}
45+
sort(dims.begin(), dims.end());
4246
bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
4347
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
4448
if (reduce_all) {
@@ -49,14 +53,22 @@ class ReduceOp : public framework::OperatorWithKernel {
4953
ctx->SetOutputDim("Out", {1});
5054
} else {
5155
auto dims_vector = vectorize(x_dims);
52-
if (keep_dim || x_rank == 1) {
53-
dims_vector[dim] = 1;
56+
if (keep_dim) {
57+
for (size_t i = 0; i < dims.size(); ++i) {
58+
dims_vector[dims[i]] = 1;
59+
}
5460
} else {
55-
dims_vector.erase(dims_vector.begin() + dim);
61+
const int kDelFlag = -2;
62+
for (size_t i = 0; i < dims.size(); ++i) {
63+
dims_vector[dims[i]] = kDelFlag;
64+
}
65+
dims_vector.erase(
66+
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
67+
dims_vector.end());
5668
}
5769
auto out_dims = framework::make_ddim(dims_vector);
5870
ctx->SetOutputDim("Out", out_dims);
59-
if (dim != 0) {
71+
if (dims[0] != 0) {
6072
// Only pass LoD when not reducing on the first dim.
6173
ctx->ShareLoD("X", /*->*/ "Out");
6274
}
@@ -75,11 +87,14 @@ class ReduceGradOp : public framework::OperatorWithKernel {
7587
auto x_dims = ctx->GetInputDim("X");
7688
auto x_rank = x_dims.size();
7789
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
78-
int dim = ctx->Attrs().Get<int>("dim");
79-
if (dim < 0) dim = x_rank + dim;
80-
PADDLE_ENFORCE_LT(
81-
dim, x_rank,
82-
"The dim should be in the range [-rank(input), rank(input)).");
90+
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
91+
for (size_t i = 0; i < dims.size(); ++i) {
92+
if (dims[i] < 0) dims[i] = x_rank + dims[i];
93+
PADDLE_ENFORCE_LT(
94+
dims[i], x_rank,
95+
"The dim should be in the range [-rank(input), rank(input)).");
96+
}
97+
sort(dims.begin(), dims.end());
8398
auto x_grad_name = framework::GradVarName("X");
8499
if (ctx->HasOutput(x_grad_name)) {
85100
ctx->SetOutputDim(x_grad_name, x_dims);
@@ -95,13 +110,13 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
95110
"(Tensor) The input tensor. Tensors with rank at most 6 are "
96111
"supported.");
97112
AddOutput("Out", "(Tensor) The result tensor.");
98-
AddAttr<int>(
113+
AddAttr<std::vector<int>>(
99114
"dim",
100-
"(int, default 0) The dimension to reduce. "
115+
"(list<int>, default {0}) The dimensions to reduce. "
101116
"Must be in the range [-rank(input), rank(input)). "
102-
"If `dim < 0`, the dim to reduce is `rank + dim`. "
117+
"If `dim[i] < 0`, the dims[i] to reduce is `rank + dims[i]`. "
103118
"Note that reducing on the first dim will make the LoD info lost.")
104-
.SetDefault(0);
119+
.SetDefault({0});
105120
AddAttr<bool>("keep_dim",
106121
"(bool, default false) "
107122
"If true, retain the reduced dimension with length 1.")

paddle/fluid/operators/reduce_op.h

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <vector>
1718
#include "glog/logging.h"
1819
#include "paddle/fluid/framework/eigen.h"
1920
#include "paddle/fluid/framework/op_registry.h"
@@ -109,6 +110,11 @@ struct ProdGradFunctor {
109110
}
110111
};
111112

113+
#define HANDLE_DIM(NDIM, RDIM) \
114+
if (ndim == NDIM && rdim == RDIM) { \
115+
ReduceCompute<NDIM, RDIM>(context); \
116+
}
117+
112118
template <typename DeviceContext, typename T, typename Functor>
113119
class ReduceKernel : public framework::OpKernel<T> {
114120
public:
@@ -127,51 +133,56 @@ class ReduceKernel : public framework::OpKernel<T> {
127133
Functor functor;
128134
functor(place, &x, &out, reduce_dim);
129135
} else {
130-
int rank = context.Input<Tensor>("X")->dims().size();
131-
switch (rank) {
132-
case 1:
133-
ReduceCompute<1>(context);
134-
break;
135-
case 2:
136-
ReduceCompute<2>(context);
137-
break;
138-
case 3:
139-
ReduceCompute<3>(context);
140-
break;
141-
case 4:
142-
ReduceCompute<4>(context);
143-
break;
144-
case 5:
145-
ReduceCompute<5>(context);
146-
break;
147-
case 6:
148-
ReduceCompute<6>(context);
149-
break;
150-
}
136+
int ndim = context.Input<Tensor>("X")->dims().size();
137+
int rdim = context.Attr<std::vector<int>>("dim").size();
138+
HANDLE_DIM(6, 5);
139+
HANDLE_DIM(6, 4);
140+
HANDLE_DIM(6, 3);
141+
HANDLE_DIM(6, 2);
142+
HANDLE_DIM(6, 1);
143+
HANDLE_DIM(5, 4);
144+
HANDLE_DIM(5, 3);
145+
HANDLE_DIM(5, 2);
146+
HANDLE_DIM(5, 1);
147+
HANDLE_DIM(4, 3);
148+
HANDLE_DIM(4, 2);
149+
HANDLE_DIM(4, 1);
150+
HANDLE_DIM(3, 2);
151+
HANDLE_DIM(3, 1);
152+
HANDLE_DIM(2, 1);
153+
HANDLE_DIM(1, 1);
151154
}
152155
}
153156

154157
private:
155-
template <size_t D>
158+
template <size_t D, size_t R_D>
156159
void ReduceCompute(const framework::ExecutionContext& context) const {
157160
auto* input = context.Input<Tensor>("X");
158161
auto* output = context.Output<Tensor>("Out");
159162
output->mutable_data<T>(context.GetPlace());
160163

161164
auto x = EigenTensor<T, D>::From(*input);
162165
auto x_rank = static_cast<int>(x.dimensions().size());
163-
int dim = static_cast<int>(context.Attr<int>("dim"));
164-
if (dim < 0) dim = x_rank + dim;
165-
auto reduce_dim = Eigen::array<int, 1>({{dim}});
166+
auto dims = context.Attr<std::vector<int>>("dim");
167+
auto reduce_dim = Eigen::array<int, R_D>();
168+
for (size_t i = 0; i < dims.size(); ++i) {
169+
if (dims[i] < 0) dims[i] = x_rank + dims[i];
170+
reduce_dim[i] = dims[i];
171+
}
166172
// construct the squeezed output tensor
167173
bool keep_dim = context.Attr<bool>("keep_dim");
168-
DDim dims = output->dims();
169-
auto dims_vector = vectorize(dims);
174+
DDim out_dims = output->dims();
170175
if (keep_dim && x_rank > 1) {
171-
dims_vector.erase(dims_vector.begin() + dim);
172-
dims = framework::make_ddim(dims_vector);
176+
const int kDelFlag = -2;
177+
auto dims_vector = vectorize(out_dims);
178+
for (size_t i = 0; i < dims.size(); ++i) {
179+
dims_vector[dims[i]] = kDelFlag;
180+
}
181+
dims_vector.erase(
182+
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
183+
dims_vector.end());
184+
out_dims = framework::make_ddim(dims_vector);
173185
}
174-
175186
auto& place =
176187
*context.template device_context<DeviceContext>().eigen_device();
177188
Functor functor;
@@ -180,7 +191,7 @@ class ReduceKernel : public framework::OpKernel<T> {
180191
auto out = EigenScalar<T>::From(*output);
181192
functor(place, &x, &out, reduce_dim);
182193
} else {
183-
auto out = EigenTensor<T, (D - 1)>::From(*output, dims);
194+
auto out = EigenTensor<T, (D - R_D)>::From(*output, out_dims);
184195
functor(place, &x, &out, reduce_dim);
185196
}
186197
}
@@ -245,21 +256,29 @@ class ReduceGradKernel : public framework::OpKernel<T> {
245256
auto x = EigenTensor<T, D>::From(*input0);
246257
auto x_grad = EigenTensor<T, D>::From(*output);
247258
auto x_rank = static_cast<int>(x.dimensions().size());
248-
int dim = static_cast<int>(context.Attr<int>("dim"));
249-
if (dim < 0) dim = x_rank + dim;
250-
DDim dims = input0->dims();
251-
dims[dim] = 1;
252-
auto x_reduce = EigenTensor<T, D>::From(*input1, dims);
253-
auto x_reduce_grad = EigenTensor<T, D>::From(*input2, dims);
254-
259+
auto dims = context.Attr<std::vector<int>>("dim");
260+
auto x_dims = input0->dims();
261+
auto reduced_dims_v = vectorize(x_dims);
255262
Eigen::array<int, D> broadcast_dim;
256263
for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1;
257-
broadcast_dim[dim] = input0->dims()[dim];
264+
265+
int broad_cats_times = 1;
266+
for (size_t i = 0; i < dims.size(); ++i) {
267+
if (dims[i] < 0) dims[i] = x_rank + dims[i];
268+
reduced_dims_v[dims[i]] = 1;
269+
broadcast_dim[dims[i]] = x_dims[dims[i]];
270+
broad_cats_times *= x_dims[dims[i]];
271+
}
272+
auto reduced_dims = framework::make_ddim(reduced_dims_v);
273+
auto x_reduce = EigenTensor<T, D>::From(*input1, reduced_dims);
274+
auto x_reduce_grad = EigenTensor<T, D>::From(*input2, reduced_dims);
275+
258276
auto& place =
259277
*context.template device_context<DeviceContext>().eigen_device();
278+
260279
Functor functor;
261280
functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim,
262-
broadcast_dim[dim]);
281+
broad_cats_times);
263282
}
264283
};
265284

0 commit comments

Comments (0)