@@ -16,26 +16,14 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-static const platform::DeviceContext* GetDeviceContext(
-    const platform::Place& src_place, const platform::Place& dst_place) {
-  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-
-  if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) {
-    return pool.Get(src_place);
-  } else if (platform::is_cpu_place(src_place) &&
-             platform::is_gpu_place(dst_place)) {
-    return pool.Get(dst_place);
-  } else {
-    PADDLE_THROW(
-        "Currently, model parallelism is only supported between CPU and CUDA");
-  }
-}
-
-void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
-                     Tensor* out) {
+void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
+                     Tensor *out) {
   VLOG(3) << "DeviceTransform in, src_place " << in.place()
           << " dst_place: " << dst_place;
-  auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
+
+  PADDLE_ENFORCE_NE(
+      in.place().which(), dst_place.which(),
+      "Currently, model parallelism is only supported between CPU and CUDA");
 
   // FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
   // the enforced checkings have been done in GetDeviceContext, so the
@@ -46,8 +34,7 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
   // the transforming is from CPU to GPU and the number of elements is little.
   // But the embarrassment is that this solution makes training
   // slower.
-  TensorCopy(in, dst_place, *dev_ctx, out);
-  dev_ctx->Wait();
+  TensorCopySync(in, dst_place, out);
 }
 
 }  // namespace framework
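
Note on the new check: the .which() call in the added PADDLE_ENFORCE_NE is the boost::variant API, so platform::Place is a variant over the concrete place types and which() returns the index of the active alternative; enforcing that the two indices differ rejects transforms between two places of the same kind, matching what the deleted GetDeviceContext's if/else chain used to guarantee. Below is a minimal stand-alone sketch of that idea, not Paddle's implementation: it substitutes std::variant::index() for boost::variant::which(), and the CPUPlace/CUDAPlace structs are illustrative stand-ins rather than Paddle's real definitions.

// Sketch: comparing variant indices detects "same kind of place".
#include <cassert>
#include <variant>

struct CPUPlace {};                    // stand-in for a CPU place type
struct CUDAPlace { int device = 0; };  // stand-in for a CUDA place type

using Place = std::variant<CPUPlace, CUDAPlace>;

int main() {
  Place src = CUDAPlace{0};
  Place dst = CPUPlace{};
  // Mirrors PADDLE_ENFORCE_NE(in.place().which(), dst_place.which(), ...):
  // the transform only proceeds when source and destination are
  // different kinds of place.
  assert(src.index() != dst.index());
  return 0;
}

TensorCopySync, as its name and the removed lines indicate, performs the copy and blocks until it completes, folding the old TensorCopy(...) plus dev_ctx->Wait() pair into one call; that is also what lets the patch drop GetDeviceContext entirely, since the helper existed only to validate the place pair and pick the context to wait on.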