25
25
#include " paddle/framework/op_registry.h"
26
26
#include " paddle/framework/operator.h"
27
27
28
// Maximum tensor rank the expand kernels support; used to size the
// Boost.Preprocessor-generated dispatch tables below.
#define MAX_RANK_SUPPORTED 6

// One switch-case per rank: `case R:` forwards to Expand<R>(context).
// BOOST_PP_REPEAT counts from 0, hence the `n + 1` to cover ranks 1..n.
#define EXPAND_TEMPLATE(z, n, data) \
  case n + 1: {                     \
    Expand<n + 1>(context);         \
    break;                          \
  }
// Expands to the full run of cases for ranks 1..n.
#define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)

// Filter predicate for the gradient-case generator: keeps only the (n)
// encodings where reshape rank (n / MAX_RANK_SUPPORTED) is >= reduce rank
// (n % MAX_RANK_SUPPORTED).
#define COND(n)                                           \
  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
36
39
#define EXPAND_GRAD_CASE (n ) \
37
40
case n: { \
38
41
ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
@@ -46,7 +49,6 @@ namespace paddle {
46
49
namespace operators {
47
50
48
51
using Tensor = framework::Tensor;
49
-
50
52
template <typename T, int MajorType = Eigen::RowMajor,
51
53
typename IndexType = Eigen::DenseIndex>
52
54
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
@@ -60,7 +62,7 @@ class ExpandKernel : public framework::OpKernel<T> {
60
62
void Compute (const framework::ExecutionContext& context) const override {
61
63
auto rank = context.Input <Tensor>(" X" )->dims ().size ();
62
64
switch (rank) {
63
- REP_EXPAND_TEMPLATE (6 )
65
+ REP_EXPAND_TEMPLATE (MAX_RANK_SUPPORTED )
64
66
default :
65
67
PADDLE_ENFORCE (false ,
66
68
" Only support tensor with rank being between 1 and 6." );
@@ -71,7 +73,7 @@ class ExpandKernel : public framework::OpKernel<T> {
71
73
template <int Rank>
72
74
void Expand (const framework::ExecutionContext& context) const {
73
75
auto * in0 = context.Input <Tensor>(" X" );
74
- auto & expand_times = context.Attr <std::vector<int >>(" expandTimes " );
76
+ auto & expand_times = context.Attr <std::vector<int >>(" expand_times " );
75
77
auto * out0 = context.Output <Tensor>(" Out" );
76
78
Eigen::DSizes<int , Rank> bcast_dims;
77
79
auto x_dims = in0->dims ();
@@ -91,8 +93,14 @@ class ExpandGradKernel : public framework::OpKernel<T> {
91
93
public:
92
94
void Compute (const framework::ExecutionContext& context) const override {
93
95
auto * in0 = context.Input <Tensor>(" X" );
94
- auto & expand_times = context.Attr <std::vector<int >>(" expandTimes " );
96
+ auto & expand_times = context.Attr <std::vector<int >>(" expand_times " );
95
97
auto x_dims = in0->dims ();
98
+ // 1. reshape_dims_vec is the broadcast parameter. For each dimension i,
99
+ // if expand_times[i] > 1 and x_dims[i] > 1, i will be splitted to two
100
+ // dimensions [expand_times[i], x_dims[i]].
101
+ // 2. reduce_dims_vec is the dimension parameter to compute gradients. For
102
+ // each dimension expanded, the gradients should be summed to original
103
+ // size.
96
104
std::vector<int > reshape_dims_vec;
97
105
std::vector<int > reduce_dims_vec;
98
106
for (size_t i = 0 ; i < expand_times.size (); ++i) {
@@ -110,7 +118,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
110
118
}
111
119
}
112
120
113
- int dims = reshape_dims_vec.size () * 6 + reduce_dims_vec.size () - 7 ;
121
+ int dims = reshape_dims_vec.size () * MAX_RANK_SUPPORTED +
122
+ reduce_dims_vec.size () - MAX_RANK_SUPPORTED - 1 ;
114
123
// no need reduce, just copy
115
124
if (reduce_dims_vec.size () == 0 ) {
116
125
auto * in0 = context.Input <Tensor>(framework::GradVarName (" Out" ));
@@ -132,8 +141,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
132
141
void ExpandBackward (const framework::ExecutionContext& context,
133
142
const std::vector<int >& reshape_dims_vec,
134
143
const std::vector<int >& reduce_dims_vec) const {
135
- size_t reshape_size = Dims / 6 + 1 ;
136
- size_t reduce_size = Dims % 6 + 1 ;
144
+ size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1 ;
145
+ size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1 ;
137
146
PADDLE_ENFORCE_EQ (reshape_size, reshape_dims_vec.size (),
138
147
" Inconsistent size between template Dims and "
139
148
" reshape dimensions." );
@@ -145,11 +154,11 @@ class ExpandGradKernel : public framework::OpKernel<T> {
145
154
auto x = EigenVector<T>::Flatten (*(context.Input <Tensor>(" X" )));
146
155
out0->mutable_data <T>(context.GetPlace ());
147
156
auto x_grad = EigenVector<T>::Flatten (*out0);
148
- Eigen::DSizes<int , Dims / 6 + 1 > reshape_dims;
157
+ Eigen::DSizes<int , Dims / MAX_RANK_SUPPORTED + 1 > reshape_dims;
149
158
for (size_t i = 0 ; i < reshape_size; ++i) {
150
159
reshape_dims[i] = reshape_dims_vec[i];
151
160
}
152
- Eigen::DSizes<int , Dims % 6 + 1 > reduce_dims;
161
+ Eigen::DSizes<int , Dims % MAX_RANK_SUPPORTED + 1 > reduce_dims;
153
162
for (size_t i = 0 ; i < reduce_size; ++i) {
154
163
reduce_dims[i] = reduce_dims_vec[i];
155
164
}
0 commit comments