@@ -154,17 +154,22 @@ class StackKernel : public framework::OpKernel<T> {
154
154
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
155
155
n > kMaxThreshold ) {
156
156
#ifdef __NVCC__
157
+ VLOG (10 ) << " Stack more than " << kMaxThreshold
158
+ << " tensors on GPU may be slow." ;
157
159
thrust::device_vector<const T *> device_x_vec (x_datas);
158
160
auto x_data_arr = device_x_vec.data ().get ();
159
161
#else
160
162
auto x_data_arr = x_datas.data ();
161
163
#endif
162
164
StackFunctorForRange (dev_ctx, x_data_arr, y_data, total_num, n, post);
165
+ #ifdef __NVCC__
166
+ // Wait() must be called because device_x_vec may be destructed before
167
+ // kernel ends
168
+ dev_ctx.Wait ();
169
+ #endif
163
170
}
164
171
#ifdef __NVCC__
165
172
else { // NOLINT
166
- VLOG (10 ) << " Stack more than " << kMaxThreshold
167
- << " tensors on GPU may be slow." ;
168
173
framework::Array<const T *, kMaxThreshold > x_data_arr;
169
174
for (int i = 0 ; i < n; ++i) x_data_arr[i] = x_datas[i];
170
175
StackFunctorForRange (dev_ctx, x_data_arr, y_data, total_num, n, post);
@@ -243,18 +248,23 @@ class StackGradKernel : public framework::OpKernel<T> {
243
248
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
244
249
n > kMaxThreshold ) {
245
250
#ifdef __NVCC__
251
+ VLOG (10 ) << " Stack more than " << kMaxThreshold
252
+ << " tensors on GPU may be slow." ;
246
253
thrust::device_vector<T *> device_dx_vec (dx_datas);
247
254
auto dx_data_arr = device_dx_vec.data ().get ();
248
255
#else
249
256
auto dx_data_arr = dx_datas.data ();
250
257
#endif
251
258
StackGradFunctorForRange (dev_ctx, dx_data_arr, dy_data, total_num, n,
252
259
post);
260
+ #ifdef __NVCC__
261
+ // Wait() must be called because device_dx_vec may be destructed before
262
+ // kernel ends
263
+ dev_ctx.Wait ();
264
+ #endif
253
265
}
254
266
#ifdef __NVCC__
255
267
else { // NOLINT
256
- VLOG (10 ) << " Stack more than " << kMaxThreshold
257
- << " tensors on GPU may be slow." ;
258
268
framework::Array<T *, kMaxThreshold > dx_data_arr;
259
269
for (int i = 0 ; i < n; ++i) dx_data_arr[i] = dx_datas[i];
260
270
StackGradFunctorForRange (dev_ctx, dx_data_arr, dy_data, total_num, n,
0 commit comments