@@ -235,11 +235,13 @@ struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
235
235
236
236
int g_find_max;
237
237
memory::Copy (platform::CPUPlace (), &g_find_max, gpu_place, find_max,
238
- sizeof (int ), 0 );
238
+ sizeof (int ), ctx.stream ());
239
+ ctx.Wait ();
239
240
if (g_find_max) {
240
241
int len;
241
242
memory::Copy (platform::CPUPlace (), &len, gpu_place, out_size_data,
242
- sizeof (int ), 0 );
243
+ sizeof (int ), ctx.stream ());
244
+ ctx.Wait ();
243
245
FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len,
244
246
out_scale_data);
245
247
}
@@ -258,25 +260,26 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
258
260
const auto gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace ());
259
261
260
262
T accum;
261
- memory::Copy (platform::CPUPlace (), &accum, gpu_place, in_accum.data <T>(),
262
- sizeof (T), 0 );
263
263
T state;
264
- memory::Copy (platform::CPUPlace (), &state, gpu_place, in_state.data <T>(),
265
- sizeof (T), 0 );
266
264
T scale;
265
+ memory::Copy (platform::CPUPlace (), &accum, gpu_place, in_accum.data <T>(),
266
+ sizeof (T), ctx.stream ());
267
+ memory::Copy (platform::CPUPlace (), &state, gpu_place, in_state.data <T>(),
268
+ sizeof (T), ctx.stream ());
267
269
memory::Copy (platform::CPUPlace (), &scale, gpu_place, cur_scale, sizeof (T),
268
- 0 );
269
-
270
+ ctx. stream () );
271
+ ctx. Wait ();
270
272
state = rate * state + 1 ;
271
273
accum = rate * accum + scale;
272
274
scale = accum / state;
273
275
274
276
memory::Copy (gpu_place, out_accum->mutable_data <T>(gpu_place),
275
- platform::CPUPlace (), &accum, sizeof (T), 0 );
277
+ platform::CPUPlace (), &accum, sizeof (T), ctx. stream () );
276
278
memory::Copy (gpu_place, out_state->mutable_data <T>(gpu_place),
277
- platform::CPUPlace (), &state, sizeof (T), 0 );
279
+ platform::CPUPlace (), &state, sizeof (T), ctx. stream () );
278
280
memory::Copy (gpu_place, out_scale->mutable_data <T>(gpu_place),
279
- platform::CPUPlace (), &scale, sizeof (T), 0 );
281
+ platform::CPUPlace (), &scale, sizeof (T), ctx.stream ());
282
+ ctx.Wait ();
280
283
}
281
284
};
282
285
0 commit comments