Skip to content

Commit 7b45363

Browse files
authored
Merge pull request #16855 from wzzju/fix_quantize_op
fix the hang bugs of memory copying. test=develop
2 parents cf5af3b + 24db3a7 commit 7b45363

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

paddle/fluid/operators/fake_quantize_op.cu

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,13 @@ struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
235235

236236
int g_find_max;
237237
memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max,
238-
sizeof(int), 0);
238+
sizeof(int), ctx.stream());
239+
ctx.Wait();
239240
if (g_find_max) {
240241
int len;
241242
memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data,
242-
sizeof(int), 0);
243+
sizeof(int), ctx.stream());
244+
ctx.Wait();
243245
FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len,
244246
out_scale_data);
245247
}
@@ -258,25 +260,26 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
258260
const auto gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
259261

260262
T accum;
261-
memory::Copy(platform::CPUPlace(), &accum, gpu_place, in_accum.data<T>(),
262-
sizeof(T), 0);
263263
T state;
264-
memory::Copy(platform::CPUPlace(), &state, gpu_place, in_state.data<T>(),
265-
sizeof(T), 0);
266264
T scale;
265+
memory::Copy(platform::CPUPlace(), &accum, gpu_place, in_accum.data<T>(),
266+
sizeof(T), ctx.stream());
267+
memory::Copy(platform::CPUPlace(), &state, gpu_place, in_state.data<T>(),
268+
sizeof(T), ctx.stream());
267269
memory::Copy(platform::CPUPlace(), &scale, gpu_place, cur_scale, sizeof(T),
268-
0);
269-
270+
ctx.stream());
271+
ctx.Wait();
270272
state = rate * state + 1;
271273
accum = rate * accum + scale;
272274
scale = accum / state;
273275

274276
memory::Copy(gpu_place, out_accum->mutable_data<T>(gpu_place),
275-
platform::CPUPlace(), &accum, sizeof(T), 0);
277+
platform::CPUPlace(), &accum, sizeof(T), ctx.stream());
276278
memory::Copy(gpu_place, out_state->mutable_data<T>(gpu_place),
277-
platform::CPUPlace(), &state, sizeof(T), 0);
279+
platform::CPUPlace(), &state, sizeof(T), ctx.stream());
278280
memory::Copy(gpu_place, out_scale->mutable_data<T>(gpu_place),
279-
platform::CPUPlace(), &scale, sizeof(T), 0);
281+
platform::CPUPlace(), &scale, sizeof(T), ctx.stream());
282+
ctx.Wait();
280283
}
281284
};
282285

0 commit comments

Comments
 (0)