
Commit 525a4fd

Author: Yancey
Merge pull request #8270 from Yancey1989/improve_concat_split_op
Improve split and concat op
2 parents b56f4a4 + 2353325, commit 525a4fd

6 files changed: +102 additions, -32 deletions

paddle/fluid/framework/ddim.cc

Lines changed: 10 additions & 0 deletions
@@ -314,5 +314,15 @@ DDim stride(const DDim& ddim) {
   }
   return framework::make_ddim(strides);
 }
+
+DDim stride_numel(const framework::DDim& ddim) {
+  std::vector<int64_t> strides(ddim.size());
+  strides[ddim.size() - 1] = ddim[ddim.size() - 1];
+  for (int i = ddim.size() - 2; i >= 0; --i) {
+    strides[i] = strides[i + 1] * ddim[i];
+  }
+  return framework::make_ddim(strides);
+}
+
 }  // namespace framework
 }  // namespace paddle
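
For context: stride_numel(ddim)[i] is the number of elements in the sub-tensor spanned by axes i through n-1 (a suffix product that includes axis i itself), whereas stride(ddim)[i] is the usual row-major element step. A standalone sketch, not part of the commit and assuming stride() is the conventional row-major stride, comparing the two on the shape used in the strided_memcpy.h comment below:

// Standalone sketch (plain C++, no Paddle dependency) of the two suffix
// products. For dims {4, 20, 100}:
//   stride       -> {2000, 100, 1}     element step between indices per axis
//   stride_numel -> {8000, 2000, 100}  numel of the sub-tensor from axis i on
#include <cstdio>
#include <vector>

int main() {
  std::vector<long long> dims = {4, 20, 100};
  const int n = static_cast<int>(dims.size());
  std::vector<long long> stride(n), stride_numel(n);
  stride[n - 1] = 1;
  stride_numel[n - 1] = dims[n - 1];
  for (int i = n - 2; i >= 0; --i) {
    stride[i] = stride[i + 1] * dims[i + 1];
    stride_numel[i] = stride_numel[i + 1] * dims[i];  // same recurrence as the diff
  }
  for (int i = 0; i < n; ++i)
    std::printf("axis %d: stride=%lld stride_numel=%lld\n", i, stride[i],
                stride_numel[i]);
  return 0;
}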

paddle/fluid/framework/ddim.h

Lines changed: 2 additions & 0 deletions
@@ -125,6 +125,8 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims);
 DDim flatten_to_1d(const DDim& src);
 
 DDim stride(const DDim& ddim);
+
+DDim stride_numel(const DDim& ddim);
 }  // namespace framework
 }  // namespace paddle
 

paddle/fluid/operators/concat_op.h

Lines changed: 19 additions & 19 deletions
@@ -28,17 +28,18 @@ class ConcatKernel : public framework::OpKernel<T> {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     auto* out = ctx.Output<framework::Tensor>("Out");
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
-    const size_t n = ins.size();
+    auto place = ctx.GetPlace();
+    out->mutable_data<T>(place);
+
+    auto out_stride = framework::stride_numel(out->dims());
+
     size_t output_offset = 0;
-    out->mutable_data<T>(ctx.GetPlace());
-    auto out_stride = framework::stride(out->dims());
-    for (size_t i = 0; i < n; i++) {
-      auto& in = ins[i];
-      auto axis_dim = in->dims()[axis];
-      auto in_stride = framework::stride(in->dims());
-      StridedMemcpy<T>(ctx.device_context(), in->data<T>(), in_stride,
-                       in->dims(), out_stride, out->data<T>() + output_offset);
-      output_offset += axis_dim * in_stride[axis];
+    for (auto* in : ins) {
+      auto in_stride = framework::stride_numel(in->dims());
+      StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
+                                  out->data<T>() + output_offset, out_stride,
+                                  in->data<T>(), in_stride);
+      output_offset += in_stride[axis];
     }
   }
 };
@@ -50,17 +51,16 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
-    const size_t n = outs.size();
     size_t input_offset = 0;
-    auto in_stride = framework::stride(in->dims());
-    for (size_t i = 0; i < n; i++) {
-      auto& out = outs[i];
+    auto in_stride = framework::stride_numel(in->dims());
+
+    for (auto& out : outs) {
       out->mutable_data<T>(ctx.GetPlace());
-      size_t axis_dim = out->dims()[axis];
-      auto out_stride = framework::stride(out->dims());
-      StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset,
-                       in_stride, out->dims(), out_stride, out->data<T>());
-      input_offset += axis_dim * in_stride[axis];
+      auto out_stride = framework::stride_numel(out->dims());
+      StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
+                                  out_stride, in->data<T>() + input_offset,
+                                  in_stride);
+      input_offset += out_stride[axis];
     }
   }
 };
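
Why output_offset now advances by in_stride[axis] alone: with stride-numel values, in_stride[axis] already counts every element an input contributes per outer chunk, so the old axis_dim * stride product is unnecessary. A worked-number sketch, not part of the commit, using hypothetical inputs of dims {2, 3, 4} and {2, 5, 4} concatenated along axis 1 into an output of dims {2, 8, 4}:

// out_stride_numel = {64, 32, 4}; in0 stride_numel = {24, 12, 4};
// in1 stride_numel = {40, 20, 4}. Each outer chunk of the output is laid out
// as [12 elements of in0][20 elements of in1] = 32 elements.
#include <cstdio>

int main() {
  const long long out_axis = 32;          // out_stride_numel[axis]
  const long long in_axis[2] = {12, 20};  // per-input in_stride_numel[axis]
  long long output_offset = 0;
  for (int k = 0; k < 2; ++k) {
    std::printf("input %d writes at base offset %lld in each output chunk\n",
                k, output_offset);
    output_offset += in_axis[k];          // mirrors the ConcatKernel loop
  }
  // output_offset == out_axis: the inputs exactly tile each output chunk.
  std::printf("chunk width: %lld (== %lld)\n", output_offset, out_axis);
  return 0;
}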

paddle/fluid/operators/split_op.h

Lines changed: 10 additions & 9 deletions
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <chrono>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/strided_memcpy.h"
@@ -27,18 +28,18 @@ class SplitOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in = ctx.Input<framework::Tensor>("X");
     auto outs = ctx.MultiOutput<framework::Tensor>("Out");
-    auto in_stride = framework::stride(in->dims());
+    auto in_stride = framework::stride_numel(in->dims());
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
-    const size_t n = outs.size();
+    auto place = ctx.GetPlace();
+
     size_t input_offset = 0;
-    for (size_t i = 0; i < n; i++) {
-      auto& out = outs[i];
+    for (auto& out : outs) {
       out->mutable_data<T>(ctx.GetPlace());
-      size_t axis_dim = out->dims()[axis];
-      auto out_stride = framework::stride(out->dims());
-      StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset,
-                       in_stride, out->dims(), out_stride, out->data<T>());
-      input_offset += axis_dim * in_stride[axis];
+      auto out_stride = framework::stride_numel(out->dims());
+      StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
+                                  out_stride, in->data<T>() + input_offset,
+                                  in_stride);
+      input_offset += out_stride[axis];
     }
   }
 };
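
SplitOpKernel runs the same arithmetic in reverse, reading each output's chunks from in->data<T>() + input_offset. A sketch, not part of the commit, tracing input_offset for the shapes used in the updated Python test below ({4, 5, 6} split along axis 1 into sections {2, 1, 2}):

// in_stride_numel = {120, 30, 6}; each output's out_stride_numel[axis] is
// its section size times the numel of the dims after the axis (6 here).
#include <cstdio>

int main() {
  const long long sections[3] = {2, 1, 2};
  const long long tail = 6;  // numel of the dims after the split axis
  long long input_offset = 0;
  for (int k = 0; k < 3; ++k) {
    const long long out_axis = sections[k] * tail;  // out_stride_numel[axis]
    std::printf("out%d reads at base offset %lld, %lld elements per chunk\n",
                k, input_offset, out_axis);
    input_offset += out_axis;  // mirrors the SplitOpKernel loop
  }
  // input_offset == 30 == in_stride_numel[axis]: sections tile the input.
  return 0;
}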

paddle/fluid/operators/strided_memcpy.h

Lines changed: 57 additions & 0 deletions
@@ -41,5 +41,62 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
   StridedCopyDimVisitor<T> func(dev_ctx, src, src_stride, dst_stride, dst);
   boost::apply_visitor(func, dst_dim);
 }
+
+// Strided numel memory copy from src to dst by the specified axis
+//
+// For example, for a tensor dims [4, 20, 100], the strided numel is
+// [8000, 2000, 100]
+//
+// NOTE: The src and dst tensor should have the same shape except along the
+// specified axis.
+template <typename T>
+inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
+                                     int64_t axis, T* dst,
+                                     const framework::DDim& dst_stride_numel,
+                                     const T* src,
+                                     const framework::DDim& src_stride_numel) {
+  int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
+  int64_t src_after = src_stride_numel[axis];
+  int64_t dst_after = dst_stride_numel[axis];
+  auto place = ctx.GetPlace();
+
+  PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
+                    "src and dst tensor should have the same dims size.");
+
+  for (int64_t i = 0; i < src_stride_numel.size(); ++i) {
+    if (i < axis) {
+      PADDLE_ENFORCE_EQ(src_stride_numel[i] / src_stride_numel[axis],
+                        dst_stride_numel[i] / dst_stride_numel[axis],
+                        "src and dst should have the same elements "
+                        "except the specified axis.");
+    } else if (i == axis) {
+      continue;
+    } else {
+      PADDLE_ENFORCE_EQ(src_stride_numel[i], dst_stride_numel[i],
+                        "src and dst should have the same elements "
+                        "except the specified axis.");
+    }
+  }
+
+  for (int64_t i = 0; i < before; ++i) {
+    if (platform::is_cpu_place(place)) {
+      auto& cpu_place = boost::get<platform::CPUPlace>(place);
+      memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
+                   src + i * src_after, sizeof(T) * src_after);
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      auto& gpu_place = boost::get<platform::CUDAPlace>(place);
+      auto& cuda_ctx =
+          reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
+      memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
+                   src + i * src_after, sizeof(T) * src_after,
+                   cuda_ctx.stream());
+#else
+      PADDLE_THROW("Paddle is not compiled with GPU");
+#endif
+    }
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
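
To make the chunked-copy logic concrete: the function copies `before` outer chunks, each of src_after contiguous elements, into destination rows dst_after elements apart. A self-contained CPU-only emulation, not Paddle code, with std::memcpy standing in for memory::Copy and a hypothetical helper name:

// Emulates the copy loop above for a concat of {2, 3} and {2, 5} along
// axis 1 into an output of dims {2, 8}. before = 2, dst_after = 8,
// src_after = 3 and 5, output base offsets 0 and 3.
#include <cstdio>
#include <cstring>
#include <vector>

// Hypothetical CPU stand-in for StridedNumelCopyWithAxis.
void CopyChunks(float* dst, long long dst_after, const float* src,
                long long src_after, long long before) {
  for (long long i = 0; i < before; ++i) {
    std::memcpy(dst + i * dst_after, src + i * src_after,
                sizeof(float) * src_after);
  }
}

int main() {
  std::vector<float> a(6, 1.f), b(10, 2.f), out(16, 0.f);
  const long long before = 2;  // dst_stride_numel[0] / dst_stride_numel[axis]
  CopyChunks(out.data() + 0, 8, a.data(), 3, before);  // output_offset = 0
  CopyChunks(out.data() + 3, 8, b.data(), 5, before);  // output_offset = 3
  for (float v : out) std::printf("%g ", v);  // 1 1 1 2 2 2 2 2, twice
  std::printf("\n");
  return 0;
}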

python/paddle/v2/fluid/tests/test_split_op.py

Lines changed: 4 additions & 4 deletions
@@ -20,11 +20,11 @@
 class TestSplitOp(OpTest):
     def setUp(self):
         self.op_type = "split"
-        axis = 0
-        x = np.random.random((4, 2, 5)).astype('float32')
-        out = np.split(x, [1, 3], axis)
+        axis = 1
+        x = np.random.random((4, 5, 6)).astype('float32')
+        out = np.split(x, [2, 3], axis)
         self.inputs = {'X': x}
-        self.attrs = {'axis': axis, 'sections': [1, 2, 1]}
+        self.attrs = {'axis': axis, 'sections': [2, 1, 2]}
         self.outputs = {'Out': [('out%d' % i, out[i]) \
                                 for i in xrange(len(out))]}
 
