@@ -14,6 +14,7 @@ limitations under the License. */
14
14
15
15
#pragma once
16
16
17
+ #include < vector>
17
18
#include " glog/logging.h"
18
19
#include " paddle/fluid/framework/eigen.h"
19
20
#include " paddle/fluid/framework/op_registry.h"
@@ -109,6 +110,11 @@ struct ProdGradFunctor {
109
110
}
110
111
};
111
112
113
+ #define HANDLE_DIM (NDIM, RDIM ) \
114
+ if (ndim == NDIM && rdim == RDIM) { \
115
+ ReduceCompute<NDIM, RDIM>(context); \
116
+ }
117
+
112
118
template <typename DeviceContext, typename T, typename Functor>
113
119
class ReduceKernel : public framework ::OpKernel<T> {
114
120
public:
@@ -127,51 +133,56 @@ class ReduceKernel : public framework::OpKernel<T> {
127
133
Functor functor;
128
134
functor (place, &x, &out, reduce_dim);
129
135
} else {
130
- int rank = context.Input <Tensor>(" X" )->dims ().size ();
131
- switch (rank) {
132
- case 1 :
133
- ReduceCompute<1 >(context);
134
- break ;
135
- case 2 :
136
- ReduceCompute<2 >(context);
137
- break ;
138
- case 3 :
139
- ReduceCompute<3 >(context);
140
- break ;
141
- case 4 :
142
- ReduceCompute<4 >(context);
143
- break ;
144
- case 5 :
145
- ReduceCompute<5 >(context);
146
- break ;
147
- case 6 :
148
- ReduceCompute<6 >(context);
149
- break ;
150
- }
136
+ int ndim = context.Input <Tensor>(" X" )->dims ().size ();
137
+ int rdim = context.Attr <std::vector<int >>(" dim" ).size ();
138
+ HANDLE_DIM (6 , 5 );
139
+ HANDLE_DIM (6 , 4 );
140
+ HANDLE_DIM (6 , 3 );
141
+ HANDLE_DIM (6 , 2 );
142
+ HANDLE_DIM (6 , 1 );
143
+ HANDLE_DIM (5 , 4 );
144
+ HANDLE_DIM (5 , 3 );
145
+ HANDLE_DIM (5 , 2 );
146
+ HANDLE_DIM (5 , 1 );
147
+ HANDLE_DIM (4 , 3 );
148
+ HANDLE_DIM (4 , 2 );
149
+ HANDLE_DIM (4 , 1 );
150
+ HANDLE_DIM (3 , 2 );
151
+ HANDLE_DIM (3 , 1 );
152
+ HANDLE_DIM (2 , 1 );
153
+ HANDLE_DIM (1 , 1 );
151
154
}
152
155
}
153
156
154
157
private:
155
- template <size_t D>
158
+ template <size_t D, size_t R_D >
156
159
void ReduceCompute (const framework::ExecutionContext& context) const {
157
160
auto * input = context.Input <Tensor>(" X" );
158
161
auto * output = context.Output <Tensor>(" Out" );
159
162
output->mutable_data <T>(context.GetPlace ());
160
163
161
164
auto x = EigenTensor<T, D>::From (*input);
162
165
auto x_rank = static_cast <int >(x.dimensions ().size ());
163
- int dim = static_cast <int >(context.Attr <int >(" dim" ));
164
- if (dim < 0 ) dim = x_rank + dim;
165
- auto reduce_dim = Eigen::array<int , 1 >({{dim}});
166
+ auto dims = context.Attr <std::vector<int >>(" dim" );
167
+ auto reduce_dim = Eigen::array<int , R_D>();
168
+ for (size_t i = 0 ; i < dims.size (); ++i) {
169
+ if (dims[i] < 0 ) dims[i] = x_rank + dims[i];
170
+ reduce_dim[i] = dims[i];
171
+ }
166
172
// construct the squeezed output tensor
167
173
bool keep_dim = context.Attr <bool >(" keep_dim" );
168
- DDim dims = output->dims ();
169
- auto dims_vector = vectorize (dims);
174
+ DDim out_dims = output->dims ();
170
175
if (keep_dim && x_rank > 1 ) {
171
- dims_vector.erase (dims_vector.begin () + dim);
172
- dims = framework::make_ddim (dims_vector);
176
+ const int kDelFlag = -2 ;
177
+ auto dims_vector = vectorize (out_dims);
178
+ for (size_t i = 0 ; i < dims.size (); ++i) {
179
+ dims_vector[dims[i]] = kDelFlag ;
180
+ }
181
+ dims_vector.erase (
182
+ remove (dims_vector.begin (), dims_vector.end (), kDelFlag ),
183
+ dims_vector.end ());
184
+ out_dims = framework::make_ddim (dims_vector);
173
185
}
174
-
175
186
auto & place =
176
187
*context.template device_context <DeviceContext>().eigen_device ();
177
188
Functor functor;
@@ -180,7 +191,7 @@ class ReduceKernel : public framework::OpKernel<T> {
180
191
auto out = EigenScalar<T>::From (*output);
181
192
functor (place, &x, &out, reduce_dim);
182
193
} else {
183
- auto out = EigenTensor<T, (D - 1 )>::From (*output, dims );
194
+ auto out = EigenTensor<T, (D - R_D )>::From (*output, out_dims );
184
195
functor (place, &x, &out, reduce_dim);
185
196
}
186
197
}
@@ -245,21 +256,29 @@ class ReduceGradKernel : public framework::OpKernel<T> {
245
256
auto x = EigenTensor<T, D>::From (*input0);
246
257
auto x_grad = EigenTensor<T, D>::From (*output);
247
258
auto x_rank = static_cast <int >(x.dimensions ().size ());
248
- int dim = static_cast <int >(context.Attr <int >(" dim" ));
249
- if (dim < 0 ) dim = x_rank + dim;
250
- DDim dims = input0->dims ();
251
- dims[dim] = 1 ;
252
- auto x_reduce = EigenTensor<T, D>::From (*input1, dims);
253
- auto x_reduce_grad = EigenTensor<T, D>::From (*input2, dims);
254
-
259
+ auto dims = context.Attr <std::vector<int >>(" dim" );
260
+ auto x_dims = input0->dims ();
261
+ auto reduced_dims_v = vectorize (x_dims);
255
262
Eigen::array<int , D> broadcast_dim;
256
263
for (size_t i = 0 ; i < D; ++i) broadcast_dim[i] = 1 ;
257
- broadcast_dim[dim] = input0->dims ()[dim];
264
+
265
+ int broad_cats_times = 1 ;
266
+ for (size_t i = 0 ; i < dims.size (); ++i) {
267
+ if (dims[i] < 0 ) dims[i] = x_rank + dims[i];
268
+ reduced_dims_v[dims[i]] = 1 ;
269
+ broadcast_dim[dims[i]] = x_dims[dims[i]];
270
+ broad_cats_times *= x_dims[dims[i]];
271
+ }
272
+ auto reduced_dims = framework::make_ddim (reduced_dims_v);
273
+ auto x_reduce = EigenTensor<T, D>::From (*input1, reduced_dims);
274
+ auto x_reduce_grad = EigenTensor<T, D>::From (*input2, reduced_dims);
275
+
258
276
auto & place =
259
277
*context.template device_context <DeviceContext>().eigen_device ();
278
+
260
279
Functor functor;
261
280
functor (place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim,
262
- broadcast_dim[dim] );
281
+ broad_cats_times );
263
282
}
264
283
};
265
284
0 commit comments