Commit 60e7ee0

refine concat_op

1 parent: cf883d9

8 files changed: +559 −49 lines

paddle/fluid/operators/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -184,6 +184,7 @@ op_library(save_op DEPS lod_tensor)
 op_library(load_op DEPS lod_tensor)
 op_library(save_combine_op DEPS lod_tensor)
 op_library(load_combine_op DEPS lod_tensor)
+op_library(concat_op DEPS concat_functor)

 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
 foreach(src ${GENERAL_OPS})

paddle/fluid/operators/concat_op.cc
Lines changed: 5 additions & 4 deletions

@@ -100,7 +100,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad,
                ops::ConcatOpGrad, false)
-REGISTER_OP_CPU_KERNEL(concat,
-                       ops::ConcatKernel<paddle::platform::CPUPlace, float>)
-REGISTER_OP_CPU_KERNEL(concat_grad,
-                       ops::ConcatGradKernel<paddle::platform::CPUPlace, float>)
+REGISTER_OP_CPU_KERNEL(
+    concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>)
+REGISTER_OP_CPU_KERNEL(
+    concat_grad,
+    ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>)

paddle/fluid/operators/concat_op.h
Lines changed: 8 additions & 45 deletions

@@ -17,6 +17,7 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/operators/strided_memcpy.h"

 namespace paddle {
@@ -27,55 +28,17 @@ class ConcatKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto* out = ctx.Output<framework::Tensor>("Out");
+    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
-
-    auto out_stride = framework::stride_numel(out->dims());
-
-    size_t output_offset = 0;
-
-    // If axis >=1, copy to out immediately need to call many times
-    // of cuda memcpy. Copy the input to cpu and do the stride copy,
-    // then copy to gpu output.
-
-    if (platform::is_gpu_place(place) && axis >= 1) {
-      platform::CPUPlace copy_place;
-      auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place);
-      framework::Tensor cpu_out;
-      cpu_out.Resize(out->dims());
-      cpu_out.mutable_data<T>(copy_place);
-      auto& dev_ctx = ctx.device_context();
-      std::vector<std::unique_ptr<framework::Tensor>> cpu_ins;
-      for (auto* in : ins) {
-        std::unique_ptr<framework::Tensor> cpu_in(new framework::Tensor);
-        framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get());
-        cpu_ins.emplace_back(std::move(cpu_in));
-      }
-      // TODO(dzhwinter): overlap copy and compute stream
-      // https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
-      dev_ctx.Wait();
-
-      for (auto& in : cpu_ins) {
-        auto& cpu_in = *in.get();
-        auto in_stride = framework::stride_numel(cpu_in.dims());
-
-        StridedNumelCopyWithAxis<T>(
-            cpu_ctx, axis, cpu_out.data<T>() + output_offset, out_stride,
-            cpu_in.data<T>(), in_stride, in_stride[axis]);
-        output_offset += in_stride[axis];
-      }
-      framework::TensorCopy(cpu_out, place, dev_ctx, out);
-    } else {
-      for (auto* in : ins) {
-        auto in_stride = framework::stride_numel(in->dims());
-        StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
-                                    out->data<T>() + output_offset, out_stride,
-                                    in->data<T>(), in_stride, in_stride[axis]);
-        output_offset += in_stride[axis];
-      }
+    std::vector<framework::Tensor> inputs(ins.size());
+    for (size_t j = 0; j < ins.size(); ++j) {
+      inputs[j] = *ins[j];
     }
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
+    concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
   }
 };

paddle/fluid/operators/math/CMakeLists.txt
Lines changed: 3 additions & 0 deletions

@@ -20,6 +20,7 @@ if(WITH_GPU)
     nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
     nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
     nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context)
+    nv_library(concat_functor SRCS concat.cc concat.cu DEPS device_context tensor)
 else()
     cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
     cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
@@ -37,10 +38,12 @@ else()
     cc_library(unpooling SRCS unpooling.cc DEPS device_context)
     cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
     cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context)
+    cc_library(concat_functor SRCS concat.cc DEPS device_context tensor)
 endif()

 cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
 cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
 cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
 cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
 cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
+cc_test(concat_test SRCS concat_test.cc DEPS concat_functor tensor)
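The new concat_test target builds concat_test.cc, which is not shown in this excerpt. As a rough illustration only (not the committed test), a minimal gtest case could drive the ConcatFunctor declared in paddle/fluid/operators/math/concat.h below on the CPU like this:

#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/concat.h"

TEST(math, concat_cpu_sketch) {
  using paddle::framework::Tensor;
  using paddle::framework::make_ddim;
  paddle::platform::CPUPlace place;
  paddle::platform::CPUDeviceContext ctx(place);

  // Two 2x2 inputs, each holding {0, 1, 2, 3}, concatenated along axis 0.
  std::vector<Tensor> inputs(2);
  for (auto& t : inputs) {
    float* p = t.mutable_data<float>(make_ddim({2, 2}), place);
    for (int i = 0; i < 4; ++i) p[i] = static_cast<float>(i);
  }
  Tensor out;
  out.mutable_data<float>(make_ddim({4, 2}), place);

  paddle::operators::math::ConcatFunctor<paddle::platform::CPUDeviceContext,
                                         float>
      concat;
  concat(ctx, inputs, 0, &out);

  // The 4x2 output is the two inputs stacked: {0, 1, 2, 3, 0, 1, 2, 3}.
  const float* o = out.data<float>();
  for (int i = 0; i < 8; ++i) EXPECT_EQ(static_cast<float>(i % 4), o[i]);
}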

paddle/fluid/operators/math/concat.cc
Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/concat.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/*
+ * All tensors' dimensions should be the same.
+ */
+template <typename T>
+class ConcatFunctor<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  std::vector<framework::Tensor>& input, const int axis,
+                  framework::Tensor* output) {
+    // Assume the number of inputs is small (less than 8) and revisit the
+    // performance later.
+    int num = input.size();
+    // Save the output's original shape; it is restored after the computation.
+    auto out_dim = output->dims();
+
+    // Flatten each tensor to a matrix: `rows` is the product of the
+    // dimensions before `axis`, the columns are everything that follows.
+    int rows = 1;
+    auto dim_0 = input[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      rows *= dim_0[i];
+    }
+    int cols = input[0].numel() / rows;
+    int out_rows = rows, out_cols = 0;
+    bool sameShape = true;
+
+    // Reshape the inputs to matrices.
+    for (int i = 0; i < num; ++i) {
+      int t_cols = input[i].numel() / rows;
+      if (sameShape) {
+        if (t_cols != cols) sameShape = false;
+      }
+      out_cols += t_cols;
+      input[i].Resize({rows, t_cols});
+    }
+    output->Resize({out_rows, out_cols});
+    auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
+
+    // Assemble each output row from the matching row of every input.
+    for (int k = 0; k < rows; ++k) {
+      T* dst_ptr = output->data<T>() + k * out_cols;
+      int col_idx = 0;
+      for (int j = 0; j < num; ++j) {
+        int col_len = input[j].dims()[1];
+        const T* src_ptr = input[j].data<T>() + k * col_len;
+        memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_ptr,
+                     sizeof(T) * col_len);
+        col_idx += col_len;
+      }
+    }
+
+    // Recover the output's original shape. The inputs stay flattened; see
+    // the note in concat.h.
+    output->Resize(out_dim);
+  }
+};
+
+template class ConcatFunctor<platform::CPUDeviceContext, int>;
+template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
+template class ConcatFunctor<platform::CPUDeviceContext, float>;
+template class ConcatFunctor<platform::CPUDeviceContext, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
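The CPU functor reduces concatenation to one memory::Copy per input per row, after flattening every tensor to a 2-D matrix: the rows come from the dimensions before the axis, the columns from everything after it. A worked example of that bookkeeping, with shapes chosen purely for illustration:

#include <cstdio>

int main() {
  // Concatenate shapes [2, 3, 4] and [2, 3, 6] along axis = 2.
  const int rows = 2 * 3;                // dims before the axis -> 6 rows
  const int cols0 = (2 * 3 * 4) / rows;  // first input flattens to 6 x 4
  const int cols1 = (2 * 3 * 6) / rows;  // second input flattens to 6 x 6
  const int out_cols = cols0 + cols1;    // output flattens to 6 x 10
  // Each of the 6 output rows is assembled with two copies: 4 elements
  // from input 0 followed by 6 elements from input 1.
  std::printf("%d x %d\n", rows, out_cols);  // prints "6 x 10"
  return 0;
}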

paddle/fluid/operators/math/concat.cu
Lines changed: 154 additions & 0 deletions

@@ -0,0 +1,154 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+// TODO(zcd): This could be replaced by a Tensor (which may require adding
+// int8 to VarType::Type) or by a TensorArray.
+static constexpr int MaxSize = 32;
+template <typename T>
+struct CUDADeviceArray {
+  T data[MaxSize];
+  int size;
+};
+
+// Device-side equivalent of std::upper_bound: returns the index of the first
+// element in [first, first + count) that is greater than `val`.
+template <typename T>
+__device__ T upper_bound(const T* first, T count, T val) {
+  const T* orig = first;
+  const T* it = nullptr;
+  T step = 0;
+  while (count > 0) {
+    it = first;
+    step = count / 2;
+    it += step;
+    if (!(val < *it)) {
+      first = ++it;
+      count -= step + 1;
+    } else {
+      count = step;
+    }
+  }
+  return first - orig;
+}
+
+// tid_x indexes the output column, tid_y the output row. `input_cols` holds
+// the cumulative column offsets, so a binary search maps an output column to
+// the input segment that owns it.
+template <typename T>
+__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
+                             const CUDADeviceArray<int> input_cols,
+                             const int output_rows, const int output_cols,
+                             T* output) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+  int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1;
+
+  int curr_offset = input_cols.data[segment];
+  int curr_segment = segment;
+  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
+    int curr_col_offset;
+    while ((curr_col_offset = input_cols.data[curr_segment + 1]) <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+    const T* input_ptr = inputs.data[curr_segment];
+    // Use a fresh row index so later column iterations start from tid_y
+    // again instead of reusing the exhausted counter.
+    for (int row = tid_y; row < output_rows; row += blockDim.y * gridDim.y)
+      output[row * output_cols + tid_x] =
+          input_ptr[row * segment_width + local_col];
+  }
+}
+
+/*
+ * All tensors' dimensions should be the same.
+ */
+template <typename T>
+class ConcatFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  std::vector<framework::Tensor>& input, const int axis,
+                  framework::Tensor* output) {
+    // Assume the number of inputs is small (at most MaxSize) and revisit
+    // the performance later.
+    int num = input.size();
+    // Save the output's original shape; it is restored after the computation.
+    auto out_dim = output->dims();
+
+    // Flatten each tensor to a matrix, as in the CPU implementation.
+    int rows = 1;
+    auto dim_0 = input[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      rows *= dim_0[i];
+    }
+    int cols = input[0].numel() / rows;
+    int out_rows = rows, out_cols = 0;
+    bool sameShape = true;
+
+    CUDADeviceArray<const T*> inputs_data;
+    CUDADeviceArray<int> inputs_cols;
+    inputs_data.size = num;
+    inputs_cols.size = num + 1;
+    inputs_cols.data[0] = 0;
+    // Reshape the inputs to matrices and record the cumulative column
+    // offsets alongside the device pointers.
+    for (int i = 0; i < num; ++i) {
+      int t_cols = input[i].numel() / rows;
+      if (sameShape) {
+        if (t_cols != cols) sameShape = false;
+      }
+      out_cols += t_cols;
+      input[i].Resize({rows, t_cols});
+      inputs_cols.data[i + 1] = out_cols;
+      inputs_data.data[i] = input[i].data<T>();
+    }
+    output->Resize({out_rows, out_cols});
+
+    // Tile the out_rows x out_cols output with 256-thread blocks.
+    const int kThreadsPerBlock = 256;
+    int block_cols = std::min(out_cols, kThreadsPerBlock);
+    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+    dim3 block_size = dim3(block_cols, block_rows, 1);
+
+    int grid_cols = (out_cols + block_cols - 1) / block_cols;
+    int grid_rows = (out_rows + block_rows - 1) / block_rows;
+    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
+
+    KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
+        inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
+
+    // Recover the output's original shape. The inputs stay flattened; see
+    // the note in concat.h.
+    output->Resize(out_dim);
+  }
+};
+
+template class ConcatFunctor<platform::CUDADeviceContext, int>;
+template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
+template class ConcatFunctor<platform::CUDADeviceContext, float>;
+template class ConcatFunctor<platform::CUDADeviceContext, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
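The upper_bound device helper is a hand-rolled binary search with the same contract as std::upper_bound: it returns the index of the first cumulative offset greater than the thread's column, so upper_bound(...) - 1 is the segment that owns that column. A small host-side check of the equivalence (illustration only, not part of the commit):

#include <algorithm>
#include <cassert>

int main() {
  // Cumulative column offsets for three inputs of widths 4, 6, and 2,
  // mirroring inputs_cols in the CUDA functor above.
  const int cols[] = {0, 4, 10, 12};
  for (int tid_x = 0; tid_x < 12; ++tid_x) {
    // std::upper_bound plays the role of the __device__ upper_bound.
    int segment =
        static_cast<int>(std::upper_bound(cols, cols + 4, tid_x) - cols) - 1;
    // The owning segment's half-open column range must contain tid_x.
    assert(cols[segment] <= tid_x && tid_x < cols[segment + 1]);
  }
  return 0;
}

The launch configuration then covers the flattened out_rows x out_cols output with 256-thread blocks (block_cols = min(out_cols, 256), block_rows = max(256 / block_cols, 1)); the grid-stride loops in the kernel let a bounded grid handle any output size.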

paddle/fluid/operators/math/concat.h
Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <vector>
+#include "paddle/fluid/framework/tensor.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/*
+ * Concatenate the input tensors along the given axis.
+ * Note: the functor reshapes the input tensors in place, so the
+ * second parameter is not const.
+ */
+template <typename DeviceContext, typename T>
+class ConcatFunctor {
+ public:
+  void operator()(const DeviceContext& context,
+                  std::vector<framework::Tensor>& input, const int axis,
+                  framework::Tensor* output);
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
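Because the functor flattens its inputs to 2-D matrices in place and does not restore them, a caller that still needs the original shapes must save and restore the dims around the call. A hedged sketch of such a wrapper (ConcatPreservingShapes is a hypothetical name, not part of this commit):

#include <vector>
#include "paddle/fluid/operators/math/concat.h"

// Illustration only: concat that leaves the input shapes untouched.
template <typename DeviceContext, typename T>
void ConcatPreservingShapes(const DeviceContext& ctx,
                            std::vector<paddle::framework::Tensor>& inputs,
                            int axis, paddle::framework::Tensor* out) {
  std::vector<paddle::framework::DDim> origin_dims(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) origin_dims[i] = inputs[i].dims();

  paddle::operators::math::ConcatFunctor<DeviceContext, T> concat;
  concat(ctx, inputs, axis, out);

  // Undo the in-place flattening performed by the functor.
  for (size_t i = 0; i < inputs.size(); ++i) inputs[i].Resize(origin_dims[i]);
}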
