@@ -16,31 +16,25 @@ limitations under the License. */
16
16
namespace paddle {
17
17
namespace framework {
18
18
19
- static const platform::DeviceContext* GetDeviceContext (
20
- const platform::Place& src_place, const platform::Place& dst_place) {
21
- platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance ();
22
-
23
- if (platform::is_gpu_place (src_place) && platform::is_cpu_place (dst_place)) {
24
- return pool.Get (src_place);
25
- } else if (platform::is_cpu_place (src_place) &&
26
- platform::is_gpu_place (dst_place)) {
27
- return pool.Get (dst_place);
28
- } else {
29
- PADDLE_THROW (
30
- " Currently, model parallelism is only supported between CPU and CUDA" );
31
- }
32
- }
33
-
34
- void TransDataDevice (const Tensor& in, const platform::Place& dst_place,
35
- Tensor* out) {
19
+ void TransDataDevice (const Tensor &in, const platform::Place &dst_place,
20
+ Tensor *out) {
36
21
VLOG (3 ) << " DeviceTransform in, src_place " << in.place ()
37
22
<< " dst_place: " << dst_place;
38
- auto * dev_ctx = GetDeviceContext (in.place (), dst_place);
39
23
40
- TensorCopy (in, dst_place, *dev_ctx, out);
41
- if (platform::is_gpu_place (in.place ()) && platform::is_cpu_place (dst_place)) {
42
- dev_ctx->Wait ();
43
- }
24
+ PADDLE_ENFORCE_NE (
25
+ in.place ().which (), dst_place.which (),
26
+ " Currently, model parallelism is only supported between CPU and CUDA" );
27
+
28
+ // FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
29
+ // the enforced checks have been done in GetDeviceContext, so the
30
+ // `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
31
+ // slow, especially when the number of elements is small, for example,
32
+ // the learning rate has only one element and it resides on the CPU side.
33
+ // One solution is to use a CUDA kernel to complete the copy operation when
34
+ // the transfer is from CPU to GPU and the number of elements is small.
35
+ // But the embarrassment is that this solution makes training
36
+ // slower.
37
+ TensorCopySync (in, dst_place, out);
44
38
}
45
39
46
40
} // namespace framework
0 commit comments