@@ -83,7 +83,8 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
83
83
PADDLE_ENFORCE_LT (
84
84
lod[0 ][i] + offset_data[i] + length_data[i],
85
85
lod[0 ][i + 1 ],
86
- " The target tensor's length overflow" )}
86
+ " The target tensor's length overflow" )
87
+ }
87
88
88
89
out->mutable_data <T>(ctx.GetPlace ());
89
90
auto out_lod = SequenceSliceLoD (*in, offset_data, length_data);
@@ -140,27 +141,29 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
140
141
auto lod = in->lod ();
141
142
auto out_lod = out_grad->lod ();
142
143
143
- x_grad->mutable_data <T>(ctx.GetPlace ());
144
- math::SetConstant<Place, T> set_zero;
145
- set_zero (ctx.device_context (), x_grad, static_cast <T>(0 ));
144
+ if (x_grad) {
145
+ x_grad->mutable_data <T>(ctx.GetPlace ());
146
+ math::SetConstant<Place, T> set_zero;
147
+ set_zero (ctx.device_context (), x_grad, static_cast <T>(0 ));
146
148
147
- auto out_grad_stride = framework::stride (out_grad->dims ());
149
+ auto out_grad_stride = framework::stride (out_grad->dims ());
148
150
149
- for (size_t i = 0 ; i < out_lod[0 ].size () - 1 ; ++i) {
150
- Tensor out_grad_t =
151
- out_grad->Slice (static_cast <int >(out_lod[0 ][i]),
152
- static_cast <int >(out_lod[0 ][i + 1 ]));
153
- auto out_grad_stride = framework::stride (out_grad_t .dims ());
151
+ for (size_t i = 0 ; i < out_lod[0 ].size () - 1 ; ++i) {
152
+ Tensor out_grad_t =
153
+ out_grad->Slice (static_cast <int >(out_lod[0 ][i]),
154
+ static_cast <int >(out_lod[0 ][i + 1 ]));
155
+ auto out_grad_stride = framework::stride (out_grad_t .dims ());
154
156
155
- auto x_grad_stride = framework::stride (x_grad->dims ());
157
+ auto x_grad_stride = framework::stride (x_grad->dims ());
156
158
157
- Tensor x_grad_t = x_grad->Slice (
158
- static_cast <int >(lod[0 ][i] + offset_data[i]),
159
- static_cast <int >(lod[0 ][i] + offset_data[i] + length_data[i]));
159
+ Tensor x_grad_t = x_grad->Slice (
160
+ static_cast <int >(lod[0 ][i] + offset_data[i]),
161
+ static_cast <int >(lod[0 ][i] + offset_data[i] + length_data[i]));
160
162
161
- StridedMemcpy<T>(ctx.device_context (), out_grad_t .data <T>(),
162
- out_grad_stride, out_grad_t .dims (), x_grad_stride,
163
- x_grad_t .data <T>());
163
+ StridedMemcpy<T>(ctx.device_context (), out_grad_t .data <T>(),
164
+ out_grad_stride, out_grad_t .dims (), x_grad_stride,
165
+ x_grad_t .data <T>());
166
+ }
164
167
}
165
168
}
166
169
};
0 commit comments