Skip to content

Commit 17a076d

Browse files
committed
replace TensorCopy with TensorCopySync
1 parent fa61320 commit 17a076d

File tree

1 file changed

+7
-20
lines changed

1 file changed

+7
-20
lines changed

paddle/fluid/framework/data_device_transform.cc

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,14 @@ limitations under the License. */
1616
namespace paddle {
1717
namespace framework {
1818

19-
static const platform::DeviceContext* GetDeviceContext(
20-
const platform::Place& src_place, const platform::Place& dst_place) {
21-
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
22-
23-
if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) {
24-
return pool.Get(src_place);
25-
} else if (platform::is_cpu_place(src_place) &&
26-
platform::is_gpu_place(dst_place)) {
27-
return pool.Get(dst_place);
28-
} else {
29-
PADDLE_THROW(
30-
"Currently, model parallelism is only supported between CPU and CUDA");
31-
}
32-
}
33-
34-
void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
35-
Tensor* out) {
19+
void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
20+
Tensor *out) {
3621
VLOG(3) << "DeviceTransform in, src_place " << in.place()
3722
<< " dst_place: " << dst_place;
38-
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
23+
24+
PADDLE_ENFORCE_NE(
25+
in.place().which(), dst_place.which(),
26+
"Currently, model parallelism is only supported between CPU and CUDA");
3927

4028
// FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
4129
// the enforced checks have been done in GetDeviceContext, so the
@@ -46,8 +34,7 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
4634
// the transforming is from CPU to GPU and the number of elements is little.
4735
// But the embarrassment is that this solution makes training
4836
// slower.
49-
TensorCopy(in, dst_place, *dev_ctx, out);
50-
dev_ctx->Wait();
37+
TensorCopySync(in, dst_place, out);
5138
}
5239

5340
} // namespace framework

0 commit comments

Comments
 (0)