Skip to content

Commit baa9f50

Browse files
committed
fix errors in multiplex_op
1 parent 2e61733 commit baa9f50

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddle/fluid/operators/multiplex_op.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
3333
auto cols = ins[0]->numel() / rows;
3434
// copy index to cpu
3535
Tensor index_t_cpu;
36-
TensorCopy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu);
36+
TensorCopySync(*ids, platform::CPUPlace(), &index_t_cpu);
3737
auto* index = index_t_cpu.data<int32_t>();
3838
auto stream = ctx.cuda_device_context().stream();
3939
platform::CUDAPlace place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
@@ -69,7 +69,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
6969
auto cols = ins[0]->numel() / rows;
7070
// copy index to cpu
7171
Tensor index_t_cpu;
72-
TensorCopy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu);
72+
TensorCopySync(*ids, platform::CPUPlace(), &index_t_cpu);
7373
auto* index = index_t_cpu.data<int32_t>();
7474

7575
auto stream = ctx.cuda_device_context().stream();

0 commit comments

Comments
 (0)