Commit fa61320

update
1 parent 4bfadcd commit fa61320

1 file changed, +10 -4 lines changed


paddle/fluid/framework/data_device_transform.cc

Lines changed: 10 additions & 4 deletions
@@ -37,11 +37,17 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
           << " dst_place: " << dst_place;
   auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
 
+  // FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
+  // the enforced checks have been done in GetDeviceContext, so the
+  // `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
+  // slow, especially when the number of elements is small, for example,
+  // the learning rate has a single element and lives on the CPU side.
+  // One solution is to use a CUDA kernel to complete the copy operation when
+  // the transform is from CPU to GPU and the number of elements is small.
+  // But the embarrassment is that this solution makes training
+  // slower.
   TensorCopy(in, dst_place, *dev_ctx, out);
-
-  if (in.place().which() != dst_place.which()) {
-    dev_ctx->Wait();
-  }
+  dev_ctx->Wait();
 }
 
 }  // namespace framework
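
The change above makes the wait unconditional after the copy. The sketch below illustrates why a wait is needed before the destination is read. It is not Paddle code: `FakeDeviceContext`, `AsyncCopy`, and `Transfer` are hypothetical names invented for this illustration, and the context is modeled as a single-threaded task queue that only runs its work when drained. A real device stream runs concurrently, but the ordering guarantee that `Wait()` provides is the same.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for a device context (not the Paddle class).
// Work is only queued by Enqueue(); Wait() runs everything queued so far.
class FakeDeviceContext {
 public:
  void Enqueue(std::function<void()> task) {
    pending_.push_back(std::move(task));
  }

  // Drain the queue; after this returns, all queued copies have completed.
  void Wait() {
    for (auto& task : pending_) task();
    pending_.clear();
  }

 private:
  std::vector<std::function<void()>> pending_;
};

// Hypothetical analogue of an asynchronous copy: it schedules the copy and
// returns immediately, without touching `dst` yet.
void AsyncCopy(const std::vector<float>& src, std::vector<float>* dst,
               FakeDeviceContext* ctx) {
  assert(dst != nullptr && ctx != nullptr);
  ctx->Enqueue([&src, dst] { *dst = src; });
}

// Same shape as the patched function: copy, then (optionally) wait.
// With wait == false the result is returned before the copy has run.
std::vector<float> Transfer(const std::vector<float>& in, bool wait) {
  FakeDeviceContext ctx;
  std::vector<float> out;
  AsyncCopy(in, &out, &ctx);
  if (wait) ctx.Wait();
  return out;
}
```

In this model, `Transfer(lr, true)` returns a copy of `lr`, while `Transfer(lr, false)` returns an empty vector because the queued copy never ran. The cost the FIXME describes is the price of that guarantee: even a one-element tensor pays for a full synchronization.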
