Skip to content

Commit 84aea8a

Browse files
author
chengduo
authored
Merge pull request #8669 from chengduoZH/feature/concat_op
Refine concat_op
2 parents 8c71ada + c3864ea commit 84aea8a

File tree

12 files changed

+883
-50
lines changed

12 files changed

+883
-50
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ op_library(save_op DEPS lod_tensor)
201201
op_library(load_op DEPS lod_tensor)
202202
op_library(save_combine_op DEPS lod_tensor)
203203
op_library(load_combine_op DEPS lod_tensor)
204+
op_library(concat_op DEPS concat_functor)
204205

205206
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
206207
foreach(src ${GENERAL_OPS})

paddle/fluid/operators/concat_op.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
100100
namespace ops = paddle::operators;
101101
REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad,
102102
ops::ConcatOpGrad, false)
103-
REGISTER_OP_CPU_KERNEL(concat,
104-
ops::ConcatKernel<paddle::platform::CPUPlace, float>)
105-
REGISTER_OP_CPU_KERNEL(concat_grad,
106-
ops::ConcatGradKernel<paddle::platform::CPUPlace, float>)
103+
REGISTER_OP_CPU_KERNEL(
104+
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>)
105+
REGISTER_OP_CPU_KERNEL(
106+
concat_grad,
107+
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>)

paddle/fluid/operators/concat_op.h

Lines changed: 38 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <utility>
1818
#include <vector>
1919
#include "paddle/fluid/framework/op_registry.h"
20+
#include "paddle/fluid/operators/math/concat.h"
2021
#include "paddle/fluid/operators/strided_memcpy.h"
2122

2223
namespace paddle {
@@ -27,54 +28,30 @@ class ConcatKernel : public framework::OpKernel<T> {
2728
public:
2829
void Compute(const framework::ExecutionContext& ctx) const override {
2930
auto ins = ctx.MultiInput<framework::Tensor>("X");
30-
auto* out = ctx.Output<framework::Tensor>("Out");
31+
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
3132
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
3233
auto place = ctx.GetPlace();
3334
out->mutable_data<T>(place);
3435

35-
auto out_stride = framework::stride_numel(out->dims());
36-
37-
size_t output_offset = 0;
38-
39-
// If axis >=1, copy to out immediately need to call many times
40-
// of cuda memcpy. Copy the input to cpu and do the stride copy,
41-
// then copy to gpu output.
42-
43-
if (platform::is_gpu_place(place) && axis >= 1) {
44-
platform::CPUPlace copy_place;
45-
auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place);
46-
framework::Tensor cpu_out;
47-
cpu_out.Resize(out->dims());
48-
cpu_out.mutable_data<T>(copy_place);
49-
auto& dev_ctx = ctx.device_context();
50-
std::vector<std::unique_ptr<framework::Tensor>> cpu_ins;
51-
for (auto* in : ins) {
52-
std::unique_ptr<framework::Tensor> cpu_in(new framework::Tensor);
53-
framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get());
54-
cpu_ins.emplace_back(std::move(cpu_in));
55-
}
56-
// TODO(dzhwinter): overlap copy and compute stream
57-
// https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
58-
dev_ctx.Wait();
59-
60-
for (auto& in : cpu_ins) {
61-
auto& cpu_in = *in.get();
62-
auto in_stride = framework::stride_numel(cpu_in.dims());
63-
64-
StridedNumelCopyWithAxis<T>(
65-
cpu_ctx, axis, cpu_out.data<T>() + output_offset, out_stride,
66-
cpu_in.data<T>(), in_stride, in_stride[axis]);
67-
output_offset += in_stride[axis];
68-
}
69-
framework::TensorCopy(cpu_out, place, dev_ctx, out);
70-
} else {
36+
// Sometimes direct copies will be faster, this maybe need deeply analysis.
37+
if (axis == 0 && ins.size() < 10) {
38+
size_t output_offset = 0;
7139
for (auto* in : ins) {
7240
auto in_stride = framework::stride_numel(in->dims());
41+
auto out_stride = framework::stride_numel(out->dims());
7342
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
7443
out->data<T>() + output_offset, out_stride,
7544
in->data<T>(), in_stride, in_stride[axis]);
7645
output_offset += in_stride[axis];
7746
}
47+
} else {
48+
std::vector<framework::Tensor> inputs(ins.size());
49+
for (size_t j = 0; j < ins.size(); ++j) {
50+
inputs[j] = *ins[j];
51+
}
52+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
53+
paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
54+
concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
7855
}
7956
}
8057
};
@@ -86,16 +63,31 @@ class ConcatGradKernel : public framework::OpKernel<T> {
8663
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
8764
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
8865
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
89-
size_t input_offset = 0;
90-
auto in_stride = framework::stride_numel(in->dims());
9166

92-
for (auto& out : outs) {
93-
out->mutable_data<T>(ctx.GetPlace());
94-
auto out_stride = framework::stride_numel(out->dims());
95-
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
96-
out_stride, in->data<T>() + input_offset,
97-
in_stride, out_stride[axis]);
98-
input_offset += out_stride[axis];
67+
// Sometimes direct copies will be faster, this maybe need deeply analysis.
68+
if (axis == 0 && outs.size() < 10) {
69+
size_t input_offset = 0;
70+
auto in_stride = framework::stride_numel(in->dims());
71+
72+
for (auto& out : outs) {
73+
out->mutable_data<T>(ctx.GetPlace());
74+
auto out_stride = framework::stride_numel(out->dims());
75+
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
76+
out_stride, in->data<T>() + input_offset,
77+
in_stride, out_stride[axis]);
78+
input_offset += out_stride[axis];
79+
}
80+
} else {
81+
std::vector<framework::Tensor> outputs(outs.size());
82+
for (size_t j = 0; j < outs.size(); ++j) {
83+
outs[j]->mutable_data<T>(ctx.GetPlace());
84+
outputs[j] = *outs[j];
85+
}
86+
87+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
88+
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
89+
concat_grad_functor;
90+
concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
9991
}
10092
}
10193
};

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ if(WITH_GPU)
2020
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
2121
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
2222
nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context)
23+
nv_library(concat_functor SRCS concat.cc concat.cu DEPS device_context tensor)
2324
else()
2425
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
2526
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
@@ -37,10 +38,12 @@ else()
3738
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
3839
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
3940
cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context)
41+
cc_library(concat_functor SRCS concat.cc DEPS device_context tensor)
4042
endif()
4143

4244
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
4345
cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
4446
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
4547
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
4648
cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
49+
cc_test(concat_test SRCS concat_test.cc DEPS concat_functor tensor)

paddle/fluid/operators/math/concat.cc

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/math/concat.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
namespace math {
20+
21+
/*
22+
* All tensors' dimension should be the same and the values of
23+
* each dimension are the same, except the axis dimension.
24+
*/
25+
template <typename T>
26+
class ConcatFunctor<platform::CPUDeviceContext, T> {
27+
public:
28+
void operator()(const platform::CPUDeviceContext& context,
29+
const std::vector<framework::Tensor>& input, const int axis,
30+
framework::Tensor* output) {
31+
// TODO(zcd): Add input data validity checking
32+
int num = input.size();
33+
34+
int rows = 1;
35+
auto dim_0 = input[0].dims();
36+
for (int i = 0; i < axis; ++i) {
37+
rows *= dim_0[i];
38+
}
39+
int out_rows = rows, out_cols = 0;
40+
41+
std::vector<int64_t> input_cols(input.size());
42+
for (int i = 0; i < num; ++i) {
43+
int t_cols = input[i].numel() / rows;
44+
out_cols += t_cols;
45+
input_cols[i] = t_cols;
46+
}
47+
auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
48+
49+
// computation
50+
for (int k = 0; k < out_rows; ++k) {
51+
T* dst_ptr = output->data<T>() + k * out_cols;
52+
int col_idx = 0;
53+
for (int j = 0; j < num; ++j) {
54+
int col_len = input_cols[j];
55+
const T* src_prt = input[j].data<T>() + k * col_len;
56+
memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt,
57+
sizeof(T) * col_len);
58+
col_idx += col_len;
59+
}
60+
}
61+
}
62+
};
63+
64+
/*
65+
* All tensors' dimension should be the same and the values of
66+
* each dimension are the same, except the axis dimension.
67+
*/
68+
template <typename T>
69+
class ConcatGradFunctor<platform::CPUDeviceContext, T> {
70+
public:
71+
void operator()(const platform::CPUDeviceContext& context,
72+
const framework::Tensor& input, const int axis,
73+
std::vector<framework::Tensor>& outputs) {
74+
// TODO(zcd): Add input data validity checking
75+
int num = outputs.size();
76+
77+
int input_rows = 1;
78+
auto dim_0 = outputs[0].dims();
79+
for (int i = 0; i < axis; ++i) {
80+
input_rows *= dim_0[i];
81+
}
82+
int input_cols = 0;
83+
84+
std::vector<int64_t> output_cols(outputs.size());
85+
for (int i = 0; i < num; ++i) {
86+
int t_cols = outputs[i].numel() / input_rows;
87+
input_cols += t_cols;
88+
output_cols[i] = t_cols;
89+
}
90+
auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
91+
92+
// computation
93+
for (int k = 0; k < input_rows; ++k) {
94+
const T* src_ptr = input.data<T>() + k * input_cols;
95+
int col_idx = 0;
96+
for (int j = 0; j < num; ++j) {
97+
int col_len = output_cols[j];
98+
T* dst_ptr = outputs[j].data<T>() + k * col_len;
99+
memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
100+
sizeof(T) * col_len);
101+
col_idx += col_len;
102+
}
103+
}
104+
}
105+
};
106+
107+
template class ConcatFunctor<platform::CPUDeviceContext, int>;
108+
template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
109+
template class ConcatFunctor<platform::CPUDeviceContext, float>;
110+
template class ConcatFunctor<platform::CPUDeviceContext, double>;
111+
112+
template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
113+
template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
114+
template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
115+
template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
116+
117+
} // namespace math
118+
} // namespace operators
119+
} // namespace paddle

0 commit comments

Comments
 (0)