File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -126,6 +126,7 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
126
126
int64_t h = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
127
127
auto in_g_e = EigenMatrix<T>::From (in_g_t , {h, w});
128
128
auto out_g_e = EigenMatrix<T>::From (out_g_t , {1 , w});
129
+ auto out_g_e_v = EigenVector<T>::Flatten (out_g_t );
129
130
Eigen::DSizes<int , 2 > bcast (h, 1 );
130
131
131
132
if (pooltype == " AVERAGE" ) {
@@ -136,9 +137,9 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
136
137
in_g_e.device (place) =
137
138
(out_g_e / std::sqrt (static_cast <T>(h))).broadcast (bcast);
138
139
} else if (pooltype == " LAST" ) {
139
- in_g_e.chip (h - 1 , 0 ).device (place) = out_g_e ;
140
+ in_g_e.chip (h - 1 , 0 ).device (place) = out_g_e_v ;
140
141
} else if (pooltype == " FIRST" ) {
141
- in_g_e.chip (0 , 0 ).device (place) = out_g_e ;
142
+ in_g_e.chip (0 , 0 ).device (place) = out_g_e_v ;
142
143
} else {
143
144
PADDLE_THROW (" unsupported pooling pooltype" );
144
145
}
You can’t perform that action at this time.
0 commit comments