@@ -29,22 +29,13 @@ template <typename T, int MajorType = Eigen::RowMajor,
29
29
typename IndexType = Eigen::DenseIndex>
30
30
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
31
31
32
- enum SeqPoolType {
33
- AVERAGE = 0 ,
34
- SUM = 1 ,
35
- SQRT = 2 , // square_root_n
36
- MAX = 3 ,
37
- LAST = 4 ,
38
- FIRST = 5
39
- };
40
-
41
32
template <typename Place, typename T>
42
33
class SequencePoolKernel : public framework ::OpKernel<T> {
43
34
public:
44
35
void Compute (const framework::ExecutionContext& context) const override {
45
36
auto * in = context.Input <LoDTensor>(" X" );
46
37
auto * out = context.Output <LoDTensor>(" Out" );
47
- int strategy = context.Attr <int >(" strategy " );
38
+ std::string pooltype = context.Attr <std::string >(" pooltype " );
48
39
49
40
auto dims = in->dims ();
50
41
auto lod = in->lod ();
@@ -71,28 +62,21 @@ class SequencePoolKernel : public framework::OpKernel<T> {
71
62
auto in_e = EigenMatrix<T>::From (in_t , framework::make_ddim ({h, w}));
72
63
auto out_e = EigenVector<T>::Flatten (out_t );
73
64
74
- switch (strategy) {
75
- case AVERAGE:
76
- out_e.device (place) = in_e.mean (Eigen::array<int , 1 >({{0 }}));
77
- break ;
78
- case SUM:
79
- out_e.device (place) = in_e.sum (Eigen::array<int , 1 >({{0 }}));
80
- break ;
81
- case SQRT:
82
- out_e.device (place) = in_e.sum (Eigen::array<int , 1 >({{0 }})) /
83
- std::sqrt (static_cast <T>(h));
84
- break ;
85
- case MAX:
86
- out_e.device (place) = in_e.maximum (Eigen::array<int , 1 >({{0 }}));
87
- break ;
88
- case LAST:
89
- out_e.device (place) = in_e.chip (h - 1 , 0 );
90
- break ;
91
- case FIRST:
92
- out_e.device (place) = in_e.chip (0 , 0 );
93
- break ;
94
- default :
95
- PADDLE_THROW (" unsupported pooling strategy" );
65
+ if (pooltype == " AVERAGE" ) {
66
+ out_e.device (place) = in_e.mean (Eigen::array<int , 1 >({{0 }}));
67
+ } else if (pooltype == " SUM" ) {
68
+ out_e.device (place) = in_e.sum (Eigen::array<int , 1 >({{0 }}));
69
+ } else if (pooltype == " SQRT" ) {
70
+ out_e.device (place) = in_e.sum (Eigen::array<int , 1 >({{0 }})) /
71
+ std::sqrt (static_cast <T>(h));
72
+ } else if (pooltype == " MAX" ) {
73
+ out_e.device (place) = in_e.maximum (Eigen::array<int , 1 >({{0 }}));
74
+ } else if (pooltype == " LAST" ) {
75
+ out_e.device (place) = in_e.chip (h - 1 , 0 );
76
+ } else if (pooltype == " FIRST" ) {
77
+ out_e.device (place) = in_e.chip (0 , 0 );
78
+ } else {
79
+ PADDLE_THROW (" unsupported pooling pooltype" );
96
80
}
97
81
}
98
82
}
@@ -105,15 +89,15 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
105
89
auto * in = context.Input <LoDTensor>(" X" );
106
90
auto * in_g = context.Output <LoDTensor>(framework::GradVarName (" X" ));
107
91
auto * out_g = context.Input <LoDTensor>(framework::GradVarName (" Out" ));
108
- int strategy = context.Attr <int >(" strategy " );
92
+ std::string pooltype = context.Attr <std::string >(" pooltype " );
109
93
110
94
auto dims = in->dims ();
111
95
auto lod = in->lod ()[0 ];
112
96
int64_t w = in->numel () / dims[0 ];
113
97
114
98
in_g->mutable_data <T>(context.GetPlace ());
115
- if (strategy == LAST || strategy == FIRST) {
116
- // set X@Grad be zero at first when strategy is LAST/FIRST
99
+ if (pooltype == " LAST" || pooltype == " FIRST" ) {
100
+ // set X@Grad be zero at first when pooltype is LAST/FIRST
117
101
math::SetConstant<Place, T> functor;
118
102
functor (context.device_context (), in_g, 0 );
119
103
}
@@ -127,41 +111,33 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
127
111
auto out_g_e = EigenMatrix<T>::From (out_g_t , {1 , w});
128
112
Eigen::DSizes<int , 2 > bcast (h, 1 );
129
113
130
- switch (strategy) {
131
- case AVERAGE:
132
- in_g_e.device (place) = (out_g_e / static_cast <T>(h)).broadcast (bcast);
133
- break ;
134
- case SUM:
135
- in_g_e.device (place) = (out_g_e).broadcast (bcast);
136
- break ;
137
- case SQRT:
138
- in_g_e.device (place) =
139
- (out_g_e / std::sqrt (static_cast <T>(h))).broadcast (bcast);
140
- break ;
141
- case MAX: {
142
- auto in_t =
143
- in->Slice (static_cast <int >(lod[i]), static_cast <int >(lod[i + 1 ]));
144
- Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
145
- in_t_map (in_t .data <T>(), h, w);
146
- int row_id;
147
- Eigen::array<int , 2 > extents{{1 , 1 }};
148
- for (int col_id = 0 ; col_id < w; col_id++) {
149
- in_t_map.col (col_id).maxCoeff (&row_id);
150
- Eigen::array<int , 2 > in_offsets{{row_id, col_id}};
151
- Eigen::array<int , 2 > out_offsets{{0 , col_id}};
152
- in_g_e.slice (in_offsets, extents).device (place) =
153
- out_g_e.slice (out_offsets, extents);
154
- }
155
- break ;
114
+ if (pooltype == " AVERAGE" ) {
115
+ in_g_e.device (place) = (out_g_e / static_cast <T>(h)).broadcast (bcast);
116
+ } else if (pooltype == " SUM" ) {
117
+ in_g_e.device (place) = (out_g_e).broadcast (bcast);
118
+ } else if (pooltype == " SQRT" ) {
119
+ in_g_e.device (place) =
120
+ (out_g_e / std::sqrt (static_cast <T>(h))).broadcast (bcast);
121
+ } else if (pooltype == " MAX" ) {
122
+ auto in_t =
123
+ in->Slice (static_cast <int >(lod[i]), static_cast <int >(lod[i + 1 ]));
124
+ Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
125
+ in_t_map (in_t .data <T>(), h, w);
126
+ int row_id;
127
+ Eigen::array<int , 2 > extents{{1 , 1 }};
128
+ for (int col_id = 0 ; col_id < w; col_id++) {
129
+ in_t_map.col (col_id).maxCoeff (&row_id);
130
+ Eigen::array<int , 2 > in_offsets{{row_id, col_id}};
131
+ Eigen::array<int , 2 > out_offsets{{0 , col_id}};
132
+ in_g_e.slice (in_offsets, extents).device (place) =
133
+ out_g_e.slice (out_offsets, extents);
156
134
}
157
- case LAST:
158
- in_g_e.chip (h - 1 , 0 ).device (place) = out_g_e;
159
- break ;
160
- case FIRST:
161
- in_g_e.chip (0 , 0 ).device (place) = out_g_e;
162
- break ;
163
- default :
164
- PADDLE_THROW (" unsupported pooling strategy" );
135
+ } else if (pooltype == " LAST" ) {
136
+ in_g_e.chip (h - 1 , 0 ).device (place) = out_g_e;
137
+ } else if (pooltype == " FIRST" ) {
138
+ in_g_e.chip (0 , 0 ).device (place) = out_g_e;
139
+ } else {
140
+ PADDLE_THROW (" unsupported pooling pooltype" );
165
141
}
166
142
}
167
143
}
0 commit comments