@@ -150,30 +150,17 @@ class StackKernel : public framework::OpKernel<T> {
150
150
int total_num = pre * n * post;
151
151
152
152
auto &dev_ctx = ctx.template device_context <DeviceContext>();
153
- constexpr auto kMaxThreshold = 16 ;
154
- if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
155
- n > kMaxThreshold ) {
156
153
#ifdef __NVCC__
157
- VLOG (10 ) << " Stack more than " << kMaxThreshold
158
- << " tensors on GPU may be slow." ;
159
- thrust::device_vector<const T *> device_x_vec (x_datas);
160
- auto x_data_arr = device_x_vec.data ().get ();
154
+ thrust::device_vector<const T *> device_x_vec (x_datas);
155
+ auto x_data_arr = device_x_vec.data ().get ();
161
156
#else
162
- auto x_data_arr = x_datas.data ();
157
+ auto x_data_arr = x_datas.data ();
163
158
#endif
164
- StackFunctorForRange (dev_ctx, x_data_arr, y_data, total_num, n, post);
159
+ StackFunctorForRange (dev_ctx, x_data_arr, y_data, total_num, n, post);
165
160
#ifdef __NVCC__
166
- // Wait() must be called because device_x_vec may be destructed before
167
- // kernel ends
168
- dev_ctx.Wait ();
169
- #endif
170
- }
171
- #ifdef __NVCC__
172
- else { // NOLINT
173
- framework::Array<const T *, kMaxThreshold > x_data_arr;
174
- for (int i = 0 ; i < n; ++i) x_data_arr[i] = x_datas[i];
175
- StackFunctorForRange (dev_ctx, x_data_arr, y_data, total_num, n, post);
176
- }
161
+ // Wait() must be called because device_x_vec may be destructed before
162
+ // kernel ends
163
+ dev_ctx.Wait ();
177
164
#endif
178
165
}
179
166
};
@@ -244,32 +231,17 @@ class StackGradKernel : public framework::OpKernel<T> {
244
231
int post = total_num / (n * pre);
245
232
246
233
auto &dev_ctx = ctx.template device_context <DeviceContext>();
247
- constexpr auto kMaxThreshold = 16 ;
248
- if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
249
- n > kMaxThreshold ) {
250
234
#ifdef __NVCC__
251
- VLOG (10 ) << " Stack more than " << kMaxThreshold
252
- << " tensors on GPU may be slow." ;
253
- thrust::device_vector<T *> device_dx_vec (dx_datas);
254
- auto dx_data_arr = device_dx_vec.data ().get ();
235
+ thrust::device_vector<T *> device_dx_vec (dx_datas);
236
+ auto dx_data_arr = device_dx_vec.data ().get ();
255
237
#else
256
- auto dx_data_arr = dx_datas.data ();
238
+ auto dx_data_arr = dx_datas.data ();
257
239
#endif
258
- StackGradFunctorForRange (dev_ctx, dx_data_arr, dy_data, total_num, n,
259
- post);
240
+ StackGradFunctorForRange (dev_ctx, dx_data_arr, dy_data, total_num, n, post);
260
241
#ifdef __NVCC__
261
- // Wait() must be called because device_dx_vec may be destructed before
262
- // kernel ends
263
- dev_ctx.Wait ();
264
- #endif
265
- }
266
- #ifdef __NVCC__
267
- else { // NOLINT
268
- framework::Array<T *, kMaxThreshold > dx_data_arr;
269
- for (int i = 0 ; i < n; ++i) dx_data_arr[i] = dx_datas[i];
270
- StackGradFunctorForRange (dev_ctx, dx_data_arr, dy_data, total_num, n,
271
- post);
272
- }
242
+ // Wait() must be called because device_dx_vec may be destructed before
243
+ // kernel ends
244
+ dev_ctx.Wait ();
273
245
#endif
274
246
}
275
247
};
0 commit comments