@@ -14,6 +14,7 @@ limitations under the License. */
14
14
15
15
#pragma once
16
16
17
+ #include < utility>
17
18
#include < vector>
18
19
#include " paddle/fluid/framework/op_registry.h"
19
20
#include " paddle/fluid/operators/strided_memcpy.h"
@@ -34,12 +35,46 @@ class ConcatKernel : public framework::OpKernel<T> {
34
35
auto out_stride = framework::stride_numel (out->dims ());
35
36
36
37
size_t output_offset = 0 ;
37
- for (auto * in : ins) {
38
- auto in_stride = framework::stride_numel (in->dims ());
39
- StridedNumelCopyWithAxis<T>(ctx.device_context (), axis,
40
- out->data <T>() + output_offset, out_stride,
41
- in->data <T>(), in_stride, in_stride[axis]);
42
- output_offset += in_stride[axis];
38
+
39
+ // If axis >= 1, copying directly into out requires many separate
40
+ // CUDA memcpy calls. Instead, copy the inputs to the CPU, do the
41
+ // strided copy there, then copy the result to the GPU output.
42
+
43
+ if (platform::is_gpu_place (place) && axis >= 1 ) {
44
+ platform::CPUPlace copy_place;
45
+ auto & cpu_ctx = *platform::DeviceContextPool::Instance ().Get (copy_place);
46
+ framework::Tensor cpu_out;
47
+ cpu_out.Resize (out->dims ());
48
+ cpu_out.mutable_data <T>(copy_place);
49
+ auto & dev_ctx = ctx.device_context ();
50
+ std::vector<std::unique_ptr<framework::Tensor>> cpu_ins;
51
+ for (auto * in : ins) {
52
+ std::unique_ptr<framework::Tensor> cpu_in (new framework::Tensor);
53
+ framework::TensorCopy (*in, copy_place, dev_ctx, cpu_in.get ());
54
+ cpu_ins.emplace_back (std::move (cpu_in));
55
+ }
56
+ // TODO(dzhwinter): overlap copy and compute stream
57
+ // https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
58
+ dev_ctx.Wait ();
59
+
60
+ for (auto & in : cpu_ins) {
61
+ auto & cpu_in = *in.get ();
62
+ auto in_stride = framework::stride_numel (cpu_in.dims ());
63
+
64
+ StridedNumelCopyWithAxis<T>(
65
+ cpu_ctx, axis, cpu_out.data <T>() + output_offset, out_stride,
66
+ cpu_in.data <T>(), in_stride, in_stride[axis]);
67
+ output_offset += in_stride[axis];
68
+ }
69
+ framework::TensorCopy (cpu_out, place, dev_ctx, out);
70
+ } else {
71
+ for (auto * in : ins) {
72
+ auto in_stride = framework::stride_numel (in->dims ());
73
+ StridedNumelCopyWithAxis<T>(ctx.device_context (), axis,
74
+ out->data <T>() + output_offset, out_stride,
75
+ in->data <T>(), in_stride, in_stride[axis]);
76
+ output_offset += in_stride[axis];
77
+ }
43
78
}
44
79
}
45
80
};
0 commit comments