Skip to content

Commit 8630ba2

Browse files
author
Qingsheng Li
authored
Fix sequence expand op (#11618)
* Set zero outside functor
1 parent 01fbcb0 commit 8630ba2

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

paddle/fluid/operators/sequence_expand_op.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,6 @@ struct SequenceExpandGradFunctor<platform::CPUDeviceContext, T> {
151151
const framework::Vector<size_t>& x_lod, /*expand source lod*/
152152
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
153153
LoDTensor* dx) {
154-
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
155-
set_zero(context, dx, static_cast<T>(0));
156-
157154
int dout_offset = 0;
158155
for (size_t i = 1; i < ref_lod.size(); ++i) {
159156
int repeat_num = ref_lod[i] - ref_lod[i - 1];
@@ -187,6 +184,10 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> {
187184
g_x->mutable_data<T>(context.GetPlace());
188185
g_x->set_lod(x->lod());
189186

187+
auto& dev_ctx = context.template device_context<DeviceContext>();
188+
math::SetConstant<DeviceContext, T> set_zero;
189+
set_zero(dev_ctx, g_x, static_cast<T>(0));
190+
190191
auto& y_lod = y->lod();
191192
if (ref_level == -1) ref_level = y_lod.size() - 1;
192193
// just copy the gradient

0 commit comments

Comments
 (0)