@@ -17,6 +17,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/framework/ddim.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/strided_memcpy.h"
 
 namespace paddle {
 namespace operators {
@@ -32,34 +33,13 @@ class ConcatKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(place);
 
     auto out_stride = framework::stride_numel(out->dims());
-    int64_t before = out_stride[0] / out_stride[axis];
-    int64_t out_after = out_stride[axis];
 
     size_t output_offset = 0;
     for (auto* in : ins) {
       auto in_stride = framework::stride_numel(in->dims());
-      int64_t in_after = in_stride[axis];
-      for (int64_t i = 0; i < before; ++i) {
-        if (platform::is_cpu_place(place)) {
-          auto& cpu_place = boost::get<platform::CPUPlace>(place);
-          memory::Copy(
-              cpu_place, out->data<T>() + output_offset + i * out_after,
-              cpu_place, in->data<T>() + i * in_after, sizeof(T) * in_after);
-        } else {
-#ifdef PADDLE_WITH_CUDA
-          auto& gpu_place = boost::get<platform::CUDAPlace>(place);
-          auto& cuda_ctx =
-              reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
-          memory::Copy(gpu_place, out->data<T>() +
-                                      output_offset + i * out_after,
-                       gpu_place, in->data<T>() + i * in_after,
-                       sizeof(T) * in_after, cuda_ctx.stream());
-#else
-          PADDLE_THROW("Paddle is not compiled with GPU");
-#endif
-        }
-      }
-      output_offset += in_after;
+      StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>() + output_offset,
+                                  out_stride, in->data<T>(), in_stride);
+      output_offset += in_stride[axis];
     }
   }
 };
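The `StridedNumelCopyWithAxis` call above replaces the duplicated CPU/CUDA copy loop that this hunk deletes. As a rough, CPU-only sketch of what that helper presumably does (the free function `strided_numel_copy_with_axis` and its raw-pointer signature below are illustrative assumptions reconstructed from the deleted loop; the real helper in `paddle/operators/strided_memcpy.h` takes a device context and `DDim` strides and also dispatches the CUDA path):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative stand-in for framework::stride_numel(dims): stride[i] is the
// number of elements spanned by one full slice starting at axis i, so
// stride[0] equals the total element count.
std::vector<int64_t> stride_numel(const std::vector<int64_t>& dims) {
  std::vector<int64_t> stride(dims.size());
  int64_t acc = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    acc *= dims[i];
    stride[i] = acc;
  }
  return stride;
}

// Hypothetical CPU-only reconstruction of the deleted loop: copy `before`
// contiguous chunks of `src` (each src_stride[axis] elements long) into
// `dst`, spacing them dst_stride[axis] elements apart.
template <typename T>
void strided_numel_copy_with_axis(int64_t axis, T* dst,
                                  const std::vector<int64_t>& dst_stride,
                                  const T* src,
                                  const std::vector<int64_t>& src_stride) {
  const int64_t before = dst_stride[0] / dst_stride[axis];
  const int64_t dst_after = dst_stride[axis];
  const int64_t src_after = src_stride[axis];
  for (int64_t i = 0; i < before; ++i) {
    std::memcpy(dst + i * dst_after, src + i * src_after,
                sizeof(T) * src_after);
  }
}
```

Under this reading, the concat forward pass is `before` row-wise copies per input, with the destination pointer advanced by `in_stride[axis]` after each input is consumed.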
@@ -73,35 +53,13 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     size_t input_offset = 0;
     auto in_stride = framework::stride_numel(in->dims());
-    auto place = ctx.GetPlace();
 
-    // numel before the specified axis
-    int64_t before = in_stride[0] / in_stride[axis];
-    int64_t in_after = in_stride[axis];
     for (auto& out : outs) {
       out->mutable_data<T>(ctx.GetPlace());
       auto out_stride = framework::stride_numel(out->dims());
-      int64_t out_after = out_stride[axis];
-      for (int64_t i = 0; i < before; ++i) {
-        if (platform::is_cpu_place(place)) {
-          auto& cpu_place = boost::get<platform::CPUPlace>(place);
-          memory::Copy(cpu_place, out->data<T>() + i * out_after, cpu_place,
-                       in->data<T>() + input_offset + i * in_after,
-                       sizeof(T) * out_after);
-        } else {
-#ifdef PADDLE_WITH_CUDA
-          auto& gpu_place = boost::get<platform::CUDAPlace>(place);
-          auto& cuda_ctx =
-              reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
-          memory::Copy(gpu_place, out->data<T>() + i * out_after, gpu_place,
-                       in->data<T>() + input_offset + i * in_after,
-                       sizeof(T) * out_after, cuda_ctx.stream());
-#else
-          PADDLE_THROW("Paddle is not compiled with GPU");
-#endif
-        }
-      }
-      input_offset += out_after;
+      StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>(), out_stride,
+                                  in->data<T>() + input_offset, in_stride);
+      input_offset += out_stride[axis];
     }
   }
 };
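The gradient kernel is the mirror image of the forward pass: forward gathers each input into `out` at a running `output_offset`, while backward scatters slices of the incoming gradient, read at a running `input_offset`, into each output tensor. After the refactor both kernels share a single copy primitive instead of two hand-rolled CPU/CUDA branches, so the place dispatch and the `PADDLE_THROW("Paddle is not compiled with GPU")` fallback live in one spot.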