@@ -25,12 +25,14 @@ void GetAccumulators<paddle::platform::CUDADeviceContext>(
25
25
auto * in_num_accumulates = ctx.Input <Tensor>(" in_num_accumulates" );
26
26
auto * in_num_updates = ctx.Input <Tensor>(" in_num_updates" );
27
27
auto stream = ctx.cuda_device_context ().stream ();
28
- memory::Copy (platform::CPUPlace (), old_num_accumulates_,
29
- platform::CUDAPlace (), in_old_num_accumulates->data <int64_t >(),
30
- sizeof (int64_t ), stream);
31
- memory::Copy (platform::CPUPlace (), num_accumulates_, platform::CUDAPlace (),
28
+ auto cuda_place =
29
+ boost::get<platform::CUDAPlace>(in_old_num_accumulates->place ());
30
+ memory::Copy (platform::CPUPlace (), old_num_accumulates_, cuda_place,
31
+ in_old_num_accumulates->data <int64_t >(), sizeof (int64_t ),
32
+ stream);
33
+ memory::Copy (platform::CPUPlace (), num_accumulates_, cuda_place,
32
34
in_num_accumulates->data <int64_t >(), sizeof (int64_t ), stream);
33
- memory::Copy (platform::CPUPlace (), num_updates_, platform::CUDAPlace () ,
35
+ memory::Copy (platform::CPUPlace (), num_updates_, cuda_place ,
34
36
in_num_updates->data <int64_t >(), sizeof (int64_t ), stream);
35
37
}
36
38
@@ -42,14 +44,16 @@ void SetAccumulators<paddle::platform::CUDADeviceContext>(
42
44
auto * out_old_num_accumulates = ctx.Output <Tensor>(" out_old_num_accumulates" );
43
45
auto * out_num_accumulates = ctx.Output <Tensor>(" out_num_accumulates" );
44
46
auto * out_num_updates = ctx.Output <Tensor>(" out_num_updates" );
47
+ auto cuda_place =
48
+ boost::get<platform::CUDAPlace>(out_old_num_accumulates->place ());
45
49
46
- memory::Copy (platform::CUDAPlace () , out_old_num_accumulates->data <int64_t >(),
50
+ memory::Copy (cuda_place , out_old_num_accumulates->data <int64_t >(),
47
51
platform::CPUPlace (), &old_num_accumulates_, sizeof (int64_t ),
48
52
stream);
49
- memory::Copy (platform::CUDAPlace () , out_num_accumulates->data <int64_t >(),
53
+ memory::Copy (cuda_place , out_num_accumulates->data <int64_t >(),
50
54
platform::CPUPlace (), &num_accumulates_, sizeof (int64_t ),
51
55
stream);
52
- memory::Copy (platform::CUDAPlace () , out_num_updates->data <int64_t >(),
56
+ memory::Copy (cuda_place , out_num_updates->data <int64_t >(),
53
57
platform::CPUPlace (), &num_updates_, sizeof (int64_t ), stream);
54
58
}
55
59
0 commit comments