
Commit e0e5453

refine the code
1 parent c976fac commit e0e5453

File tree: 3 files changed, +68 -74 lines changed


paddle/operators/concat_op.h

Lines changed: 7 additions & 49 deletions
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/framework/ddim.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/strided_memcpy.h"
 
 namespace paddle {
 namespace operators {
@@ -32,34 +33,13 @@ class ConcatKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(place);
 
     auto out_stride = framework::stride_numel(out->dims());
-    int64_t before = out_stride[0] / out_stride[axis];
-    int64_t out_after = out_stride[axis];
 
     size_t output_offset = 0;
     for (auto* in : ins) {
       auto in_stride = framework::stride_numel(in->dims());
-      int64_t in_after = in_stride[axis];
-      for (int64_t i = 0; i < before; ++i) {
-        if (platform::is_cpu_place(place)) {
-          auto& cpu_place = boost::get<platform::CPUPlace>(place);
-          memory::Copy(
-              cpu_place, out->data<T>() + output_offset + i * out_after,
-              cpu_place, in->data<T>() + i * in_after, sizeof(T) * in_after);
-        } else {
-#ifdef PADDLE_WITH_CUDA
-          auto& gpu_place = boost::get<platform::CUDAPlace>(place);
-          auto& cuda_ctx =
-              reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
-          memory::Copy(gpu_place,
-                       out->data<T>() + output_offset + i * out_after,
-                       gpu_place, in->data<T>() + i * in_after,
-                       sizeof(T) * in_after, cuda_ctx.stream());
-#else
-          PADDLE_THROW("Paddle is not compiled with GPU");
-#endif
-        }
-      }
-      output_offset += in_after;
+      StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>() + output_offset,
+                                  out_stride, in->data<T>(), in_stride);
+      output_offset += in_stride[axis];
     }
   }
 };
@@ -73,35 +53,13 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     size_t input_offset = 0;
     auto in_stride = framework::stride_numel(in->dims());
-    auto place = ctx.GetPlace();
 
-    // numel before the specified axis
-    int64_t before = in_stride[0] / in_stride[axis];
-    int64_t in_after = in_stride[axis];
     for (auto& out : outs) {
       out->mutable_data<T>(ctx.GetPlace());
       auto out_stride = framework::stride_numel(out->dims());
-      int64_t out_after = out_stride[axis];
-      for (int64_t i = 0; i < before; ++i) {
-        if (platform::is_cpu_place(place)) {
-          auto& cpu_place = boost::get<platform::CPUPlace>(place);
-          memory::Copy(cpu_place, out->data<T>() + i * out_after, cpu_place,
-                       in->data<T>() + input_offset + i * in_after,
-                       sizeof(T) * out_after);
-        } else {
-#ifdef PADDLE_WITH_CUDA
-          auto& gpu_place = boost::get<platform::CUDAPlace>(place);
-          auto& cuda_ctx =
-              reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
-          memory::Copy(gpu_place, out->data<T>() + i * out_after, gpu_place,
-                       in->data<T>() + input_offset + i * in_after,
-                       sizeof(T) * out_after, cuda_ctx.stream());
-#else
-          PADDLE_THROW("Paddle is not compiled with GPU");
-#endif
-        }
-      }
-      input_offset += out_after;
+      StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>(), out_stride,
+                                  in->data<T>() + input_offset, in_stride);
+      input_offset += out_stride[axis];
     }
   }
 };
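As a worked illustration (not part of the diff): concatenating inputs with dims [2, 3, 4] and [2, 5, 4] along axis 1 yields an output with dims [2, 8, 4]; the stride_numel values are [24, 12, 4], [40, 20, 4] and [64, 32, 4]. The kernel therefore copies the first input at output_offset 0 and the second at output_offset 12, i.e. the offset advances by in_stride[axis] after each input.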

paddle/operators/split_op.h

Lines changed: 4 additions & 25 deletions
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/framework/ddim.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/strided_memcpy.h"
 
 namespace paddle {
 namespace operators {
@@ -26,41 +27,19 @@ template <typename DeviceContext, typename T>
 class SplitOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    // auto start = std::chrono::steady_clock::now();
     auto* in = ctx.Input<framework::Tensor>("X");
     auto outs = ctx.MultiOutput<framework::Tensor>("Out");
     auto in_stride = framework::stride_numel(in->dims());
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     auto place = ctx.GetPlace();
 
-    // numel before the specified axis
-    int64_t before = in_stride[0] / in_stride[axis];
-    int64_t in_after = in_stride[axis];
     size_t input_offset = 0;
     for (auto& out : outs) {
       out->mutable_data<T>(ctx.GetPlace());
       auto out_stride = framework::stride_numel(out->dims());
-      int64_t out_after = out_stride[axis];
-      for (int64_t i = 0; i < before; ++i) {
-        if (platform::is_cpu_place(place)) {
-          auto& cpu_place = boost::get<platform::CPUPlace>(place);
-          memory::Copy(cpu_place, out->data<T>() + i * out_after, cpu_place,
-                       in->data<T>() + input_offset + i * in_after,
-                       sizeof(T) * out_after);
-        } else {
-#ifdef PADDLE_WITH_CUDA
-          auto& gpu_place = boost::get<platform::CUDAPlace>(place);
-          auto& cuda_ctx =
-              reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
-          memory::Copy(gpu_place, out->data<T>() + i * out_after, gpu_place,
-                       in->data<T>() + input_offset + i * in_after,
-                       sizeof(T) * out_after, cuda_ctx.stream());
-#else
-          PADDLE_THROW("Paddle is not compiled with GPU");
-#endif
-        }
-      }
-      input_offset += out_after;
+      StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>(), out_stride,
+                                  in->data<T>() + input_offset, in_stride);
+      input_offset += out_stride[axis];
     }
   }
 };

paddle/operators/strided_memcpy.h

Lines changed: 57 additions & 0 deletions
@@ -41,5 +41,62 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
   StridedCopyDimVisitor<T> func(dev_ctx, src, src_stride, dst_stride, dst);
   boost::apply_visitor(func, dst_dim);
 }
+
+// Strided numel memory copy from src to dst by the specified axis
+//
+// For example, for a tensor dims [4, 20, 100], the strided numel is
+// [8000, 2000, 100]
+//
+// NOTE: The src and dst tensor should have the same elements
+// except the specified axis.
+template <typename T>
+inline void StridedNumelCopyWithAxis(const framework::ExecutionContext& ctx,
+                                     int64_t axis, T* dst,
+                                     const framework::DDim& dst_stride_numel,
+                                     const T* src,
+                                     const framework::DDim& src_stride_numel) {
+  int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
+  int64_t src_after = src_stride_numel[axis];
+  int64_t dst_after = dst_stride_numel[axis];
+  auto place = ctx.GetPlace();
+
+  PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
+                    "src and dst tensor should have the same dims size.");
+
+  for (int64_t i = 0; i < axis; ++i) {
+    if (i < axis) {
+      PADDLE_ENFORCE_EQ(src_stride_numel[i] / src_stride_numel[axis],
+                        dst_stride_numel[i] / dst_stride_numel[axis],
+                        "src and dst should have the same elements "
+                        "except the specified axis.");
+    } else if (i == axis) {
+      continue;
+    } else {
+      PADDLE_ENFORCE_EQ(src_stride_numel[i], dst_stride_numel[i],
+                        "src and dst should have the same elements "
+                        "except the specified axis.");
+    }
+  }
+
+  for (int64_t i = 0; i < before; ++i) {
+    if (platform::is_cpu_place(place)) {
+      auto& cpu_place = boost::get<platform::CPUPlace>(place);
+      memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
+                   src + i * src_after, sizeof(T) * src_after);
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      auto& gpu_place = boost::get<platform::CUDAPlace>(place);
+      auto& cuda_ctx = reinterpret_cast<const platform::CUDADeviceContext&>(
+          ctx.device_context());
+      memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
+                   src + i * src_after, sizeof(T) * src_after,
+                   cuda_ctx.stream());
+#else
+      PADDLE_THROW("Paddle is not compiled with GPU");
+#endif
+    }
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
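To make the before/after decomposition concrete, here is a minimal standalone sketch of the copy that StridedNumelCopyWithAxis performs, using plain arrays instead of Paddle tensors. The StrideNumel helper, shapes, and values below are illustrative assumptions, not part of the commit.

// Minimal sketch (assumptions: row-major layout, two inputs, float data);
// mirrors the before/after copy loop of StridedNumelCopyWithAxis.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// stride_numel(dims)[i] = product of dims[i..n), e.g. [4, 20, 100] -> [8000, 2000, 100].
std::vector<int64_t> StrideNumel(const std::vector<int64_t>& dims) {
  std::vector<int64_t> stride(dims.size());
  int64_t acc = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    acc *= dims[i];
    stride[i] = acc;
  }
  return stride;
}

int main() {
  // Concatenate dims [2, 3] and [2, 5] along axis 1 into dims [2, 8].
  const int64_t axis = 1;
  std::vector<float> a(6, 1.f), b(10, 2.f), out(16, 0.f);
  std::vector<std::vector<int64_t>> in_strides = {StrideNumel({2, 3}),   // [6, 3]
                                                  StrideNumel({2, 5})};  // [10, 5]
  std::vector<const float*> in_data = {a.data(), b.data()};
  auto out_stride = StrideNumel({2, 8});  // [16, 8]

  int64_t before = out_stride[0] / out_stride[axis];  // 2 outer rows
  int64_t out_after = out_stride[axis];               // 8 elements per output row
  size_t output_offset = 0;
  for (size_t k = 0; k < in_data.size(); ++k) {
    int64_t in_after = in_strides[k][axis];  // elements copied per outer row
    for (int64_t i = 0; i < before; ++i) {
      std::memcpy(out.data() + output_offset + i * out_after,
                  in_data[k] + i * in_after, sizeof(float) * in_after);
    }
    output_offset += in_after;  // same as output_offset += in_stride[axis]
  }
  for (float v : out) std::cout << v << ' ';
  std::cout << '\n';  // prints: 1 1 1 2 2 2 2 2 1 1 1 2 2 2 2 2
}

Each of the `before` outer rows copies in_stride[axis] contiguous elements into the wider output row at the running output_offset, which is how ConcatKernel above advances its offset; SplitOpKernel performs the inverse, advancing input_offset by out_stride[axis].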
