
Commit 82bd82c

follow comments and refine code
1 parent 00e596e commit 82bd82c

File tree

4 files changed, +88 -75 lines changed


paddle/fluid/operators/concat_op.h

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,7 @@ class ConcatKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
 
+    // TODO(zcd): Sometimes direct copies will be faster
     std::vector<framework::Tensor> inputs(ins.size());
     for (size_t j = 0; j < ins.size(); ++j) {
       inputs[j] = *ins[j];
@@ -51,6 +52,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
 
+    // TODO(zcd): Sometimes direct copies will be faster
     std::vector<framework::Tensor> outputs(outs.size());
     for (size_t j = 0; j < outs.size(); ++j) {
       outs[j]->mutable_data<T>(ctx.GetPlace());
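
Note: both hunks only add a TODO comment; the surrounding kernels gather the inputs into a std::vector<framework::Tensor> and delegate to the math functors refactored in the files below. A minimal sketch of that pattern, where everything outside the shown hunk context is an assumption rather than part of this commit:

// Sketch only: approximate shape of ConcatKernel::Compute around the lines
// touched above. Names outside the hunk context are assumptions.
template <typename DeviceContext, typename T>
void ComputeSketch(const framework::ExecutionContext& ctx) {
  auto ins = ctx.MultiInput<framework::Tensor>("X");
  auto* out = ctx.Output<framework::Tensor>("Out");
  int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
  auto place = ctx.GetPlace();
  out->mutable_data<T>(place);

  // TODO(zcd): Sometimes direct copies will be faster
  std::vector<framework::Tensor> inputs(ins.size());
  for (size_t j = 0; j < ins.size(); ++j) {
    inputs[j] = *ins[j];
  }

  // Delegate to the functor declared in paddle/fluid/operators/math/concat.h.
  math::ConcatFunctor<DeviceContext, T> concat_functor;
  concat_functor(ctx.template device_context<DeviceContext>(), inputs,
                 static_cast<int>(axis), out);
}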

paddle/fluid/operators/math/concat.cc

Lines changed: 8 additions & 11 deletions
@@ -19,28 +19,25 @@ namespace operators {
 namespace math {
 
 /*
- * All tensors' dimension should be the same.
+ * All tensors' dimension should be the same and the values of
+ * each dimension are the same, except the axis dimension.
  */
 template <typename T>
 class ConcatFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const std::vector<framework::Tensor>& input, const int axis,
                   framework::Tensor* output) {
-    // assume the the max size of input is less than 8 and see the performance
-    // save origin dim
+    // TODO(zcd): Add input data validity checking
     int num = input.size();
-    std::vector<paddle::framework::DDim> origin_dim(num);
 
-    // get the matrix size
     int rows = 1;
     auto dim_0 = input[0].dims();
     for (int i = 0; i < axis; ++i) {
       rows *= dim_0[i];
     }
     int out_rows = rows, out_cols = 0;
 
-    // get input's cols
     std::vector<int64_t> input_cols(input.size());
     for (int i = 0; i < num; ++i) {
       int t_cols = input[i].numel() / rows;
@@ -64,26 +61,26 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
   }
 };
 
+/*
+ * All tensors' dimension should be the same and the values of
+ * each dimension are the same, except the axis dimension.
+ */
 template <typename T>
 class ConcatGradFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& input, const int axis,
                   std::vector<framework::Tensor>& outputs) {
-    // assume the the max size of input is less than 8 and see the performance
-    // save origin dim
+    // TODO(zcd): Add input data validity checking
     int num = outputs.size();
-    std::vector<paddle::framework::DDim> origin_dim(num);
 
-    // get the matrix size
     int input_rows = 1;
     auto dim_0 = outputs[0].dims();
     for (int i = 0; i < axis; ++i) {
       input_rows *= dim_0[i];
     }
     int input_cols = 0;
 
-    // get outputs' cols
     std::vector<int64_t> output_cols(outputs.size());
     for (int i = 0; i < num; ++i) {
       int t_cols = outputs[i].numel() / input_rows;
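
Both CPU functors rely on the same flattening: with rows equal to the product of the dimensions before axis and each tensor's cols equal to numel() / rows, concatenation along axis reduces to a column-wise block copy between 2-D matrices. A standalone sketch of that view (plain arrays instead of framework::Tensor; names are illustrative only):

#include <cstdint>
#include <vector>

// Rows of the flattened matrix: product of the dimensions before `axis`.
int64_t RowsBeforeAxis(const std::vector<int64_t>& dims, int axis) {
  int64_t rows = 1;
  for (int i = 0; i < axis; ++i) rows *= dims[i];
  return rows;
}

// Column-wise concatenation of two flattened inputs; `out` must hold
// rows * (a_cols + b_cols) elements.
void ConcatCols(const float* a, int64_t a_cols, const float* b, int64_t b_cols,
                int64_t rows, float* out) {
  const int64_t out_cols = a_cols + b_cols;
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < a_cols; ++c)
      out[r * out_cols + c] = a[r * a_cols + c];
    for (int64_t c = 0; c < b_cols; ++c)
      out[r * out_cols + a_cols + c] = b[r * b_cols + c];
  }
}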

paddle/fluid/operators/math/concat.cu

Lines changed: 57 additions & 64 deletions
@@ -12,23 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/platform/cuda_helper.h"
 
 namespace paddle {
 namespace operators {
 namespace math {
 
-// TODO(zcd): This can be replaced by tensor,
-// if that, maybe we should add int8 to VarType::Type.
-// Or replaced by tensorArray.
-static constexpr int MaxSize = 8;
-template <typename T>
-struct CUDADeviceArray {
-  T data[MaxSize];
-  int size;
-};
-
 template <typename T>
 __device__ T upper_bound(const T* first, T count, T val) {
   const T* orig = first;
@@ -49,25 +40,24 @@ __device__ T upper_bound(const T* first, T count, T val) {
 }
 
 template <typename T>
-__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
-                             const CUDADeviceArray<int> input_cols,
+__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
                              const int output_rows, const int output_cols,
                              T* output) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1;
+  int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1;
 
-  int curr_offset = input_cols.data[segment];
+  int curr_offset = input_cols[segment];
   int curr_segment = segment;
   for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
     T curr_col_offset;
-    while ((curr_col_offset = input_cols.data[curr_segment + 1]) <= tid_x) {
+    while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) {
       curr_offset = curr_col_offset;
       ++curr_segment;
     }
 
     int local_col = tid_x - curr_offset;
     int segment_width = curr_col_offset - curr_offset;
-    const T* input_ptr = inputs.data[curr_segment];
+    T* input_ptr = inputs[curr_segment];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
       output[tid_y * output_cols + tid_x] =
@@ -76,41 +66,41 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
 }
 
 template <typename T>
-__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
-                             const int input_col, const int output_rows,
-                             const int output_cols, T* output) {
+__global__ void KernelConcat(T** inputs, const int input_col,
+                             const int output_rows, const int output_cols,
+                             T* output) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   float inv_input_col = 1.0 / input_col;
   for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
     int split = tid_x * inv_input_col;
     int in_offset = tid_x - split * input_col;
-    const T* input_ptr = inputs.data[split];
+    T* input_ptr = inputs[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
+    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
       output[tid_y * output_cols + tid_x] =
           input_ptr[tid_y * input_col + in_offset];
+    }
   }
 }
 
 template <typename T>
 __global__ void KernelConcatGrad(const T* input, const int input_row,
-                                 const int input_col,
-                                 CUDADeviceArray<int> output_cols,
-                                 CUDADeviceArray<T*> outputs) {
+                                 const int input_col, const int* output_cols,
+                                 int col_size, T** outputs) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int segment = upper_bound<int>(output_cols.data, output_cols.size, tid_x) - 1;
-  int curr_offset = output_cols.data[segment];
+  int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
+  int curr_offset = output_cols[segment];
   int curr_segment = segment;
   for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
     T curr_col_offset;
-    while ((curr_col_offset = output_cols.data[curr_segment + 1]) <= tid_x) {
+    while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }
 
    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
-    T* output_ptr = outputs.data[curr_segment];
+    T* output_ptr = outputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
      output_ptr[tid_y * segment_width + local_col] =
@@ -121,13 +111,13 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
 template <typename T>
 __global__ void KernelConcatGrad(const T* input, const int input_row,
                                  const int input_col, const int output_cols,
-                                 CUDADeviceArray<T*> outputs) {
+                                 T** outputs) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   float inv_input_col = 1.0 / input_col;
   for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
     int split = tid_x * inv_input_col;
     int in_offset = tid_x - split * input_col;
-    T* output_ptr = outputs.data[split];
+    T* output_ptr = outputs[split];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
       output_ptr[tid_y * output_cols + in_offset] =
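
The general-shape kernels above map each global column index to its owning tensor by searching a prefix-sum array of column offsets; the device-side upper_bound plays the role of std::upper_bound. A host-side illustration of the same lookup (helper name is illustrative, not part of the commit):

#include <algorithm>
#include <vector>

// cols is a prefix sum of per-tensor column counts, e.g. {0, 2, 5}.
// For a global column x, the owning tensor is the last offset <= x.
int SegmentOf(const std::vector<int>& cols, int x) {
  auto it = std::upper_bound(cols.begin(), cols.end(), x);
  return static_cast<int>(it - cols.begin()) - 1;
}
// Example: with cols = {0, 2, 5}, columns 0 and 1 map to tensor 0,
// and columns 2 through 4 map to tensor 1.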
@@ -136,46 +126,45 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
 }
 
 /*
- * All tensors' dimension should be the same.
+ * All tensors' dimension should be the same and the values of
+ * each dimension are the same, except the axis dimension.
  */
 template <typename T>
 class ConcatFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const std::vector<framework::Tensor>& input, const int axis,
                   framework::Tensor* output) {
-    // assume the the max size of input is less than 8 and see the performance
-    // save origin dim
+    // TODO(zcd): Add input data validity checking
     int num = input.size();
-    PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
-                      MaxSize);
-    // get the matrix size
     int rows = 1;
     auto dim_0 = input[0].dims();
     for (int i = 0; i < axis; ++i) {
       rows *= dim_0[i];
     }
     int cols = input[0].numel() / rows;
     int out_rows = rows, out_cols = 0;
-    bool sameShape = true;
 
-    CUDADeviceArray<const T*> inputs_data;
-    CUDADeviceArray<int> inputs_cols;
-    inputs_data.size = num;
-    inputs_cols.size = num + 1;
-    inputs_cols.data[0] = 0;
-    // reshape to matrix
-    // check input shape is valid
+    paddle::framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
+    paddle::framework::Vector<int> inputs_cols(num + 1);
+    inputs_cols[0] = 0;
+    T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
+
+    bool sameShape = true;
     for (int i = 0; i < num; ++i) {
       int t_cols = input[i].numel() / rows;
       if (sameShape) {
         if (t_cols != cols) sameShape = false;
       }
       out_cols += t_cols;
-      inputs_cols.data[i + 1] = out_cols;
-      inputs_data.data[i] = input[i].data<T>();
+      inputs_cols[i + 1] = out_cols;
+      inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
     }
 
+    T** ins_gpu =
+        reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
+    const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());
+
     // computation
     // set the thread block and grid according to CurrentDeviceId
     const int kThreadsPerBlock = 1024;
@@ -198,27 +187,27 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
 
     if (sameShape) {
       KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-          inputs_data, cols, out_rows, out_cols, output->data<T>());
+          ins_gpu, cols, out_rows, out_cols, output->data<T>());
     } else {
       KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-          inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
+          ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
+          out_cols, output->data<T>());
     }
   }
 };
 
+/*
+ * All tensors' dimension should be the same and the values of
+ * each dimension are the same, except the axis dimension.
+ */
 template <typename T>
 class ConcatGradFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input, const int axis,
                   std::vector<framework::Tensor>& outputs) {
-    // assume the the max size of input is less than 8 and see the performance
-    // save origin dim
+    // TODO(zcd): Add input data validity checking
     int num = outputs.size();
-    PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
-                      MaxSize);
-
-    // get the matrix size
     int input_row = 1;
     auto dim_0 = outputs[0].dims();
     for (int i = 0; i < axis; ++i) {
@@ -229,24 +218,27 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
     int input_col = 0;
     bool sameShape = true;
 
-    CUDADeviceArray<T*> outputs_data;
-    CUDADeviceArray<int> outputs_cols;
-    outputs_data.size = num;
-    outputs_cols.size = num + 1;
-    outputs_cols.data[0] = 0;
+    paddle::framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
+    paddle::framework::Vector<int> outputs_cols(num + 1);
+    outputs_cols[0] = 0;
+    T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
 
     for (int i = 0; i < num; ++i) {
       int t_col = outputs[i].numel() / input_row;
       if (sameShape) {
         if (t_col != output_col_0) sameShape = false;
       }
       input_col += t_col;
-      outputs_cols.data[i + 1] = input_col;
-      outputs_data.data[i] = outputs[i].data<T>();
+      outputs_cols[i + 1] = input_col;
+      outputs_ptr[i] = outputs[i].data<T>();
     }
 
+    T** outs_gpu =
+        reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
+    const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());
+
     // computation
-    const int kThreadsPerBlock = 256;
+    const int kThreadsPerBlock = 1024;
     int block_cols = std::min(input_col, kThreadsPerBlock);
     int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
     dim3 block_size = dim3(block_cols, block_rows, 1);
@@ -257,10 +249,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 
     if (sameShape) {
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-          input.data<T>(), input_row, input_col, output_col_0, outputs_data);
+          input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
     } else {
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-          input.data<T>(), input_row, input_col, outputs_cols, outputs_data);
+          input.data<T>(), input_row, input_col, outs_col_gpu,
+          static_cast<int>(outputs_cols.size()), outs_gpu);
     }
   }
 };
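
The key change in this file swaps the fixed-capacity CUDADeviceArray (at most MaxSize = 8 inputs) for framework::Vector buffers mirrored to device memory through CUDAData()/CUDAMutableData(). The device pointers are packed into a Vector<int16_t> of num * sizeof(T*) / 2 elements, which is exactly num * sizeof(T*) bytes, and then reinterpreted as T** on both host and device, so the input count is no longer capped. A host-only sketch of the size arithmetic (std::vector stands in for framework::Vector; the helper is illustrative):

#include <cstdint>
#include <vector>

// Pack `num` raw pointers into an int16_t buffer and view it as T**.
// num * sizeof(T*) bytes == num * sizeof(T*) / sizeof(int16_t) int16_t slots.
template <typename T>
T** PackPointers(const std::vector<T*>& ptrs, std::vector<int16_t>* buffer) {
  buffer->resize(ptrs.size() * sizeof(T*) / sizeof(int16_t));
  T** slots = reinterpret_cast<T**>(buffer->data());
  for (size_t i = 0; i < ptrs.size(); ++i) slots[i] = ptrs[i];
  return slots;  // same storage, now addressable as an array of T*
}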

paddle/fluid/operators/math/concat.h

Lines changed: 21 additions & 0 deletions
@@ -20,7 +20,16 @@ namespace operators {
 namespace math {
 
 /*
+ * \brief Concatenate the input tensors along the dimension axis.
+ * TODO(zcd): maybe it needs to be more detailed.
+ * Examples:
+ *     Input[0] = [[1,2],[3,4]]
+ *     Input[1] = [[5,6]]
+ *     axis = 0
  *
+ *     Output = [[1,2],
+ *               [3,4],
+ *               [5,6]]
  */
 template <typename DeviceContext, typename T>
 class ConcatFunctor {
@@ -30,6 +39,18 @@ class ConcatFunctor {
                   framework::Tensor* output);
 };
 
+/*
+ * \brief Split the input tensors along the dimension axis into outputs.
+ * TODO(zcd): maybe it needs to be more detailed.
+ * Examples:
+ *     Input = [[1,2],
+ *              [3,4],
+ *              [5,6]]
+ *     axis = 0
+ *
+ *     Output[0] = [[1,2],[3,4]]
+ *     Output[1] = [[5,6]]
+ */
 template <typename DeviceContext, typename T>
 class ConcatGradFunctor {
  public:
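
For reference, a hedged usage sketch of the functor declared above, on the CPU path; the function name, float instantiation, and tensor setup are illustrative assumptions, not part of this commit:

// Inside namespace paddle::operators. `out` must already be resized and
// allocated to hold the concatenated result, as the op kernel does above.
void ConcatTwo(const platform::CPUDeviceContext& ctx,
               const framework::Tensor& a, const framework::Tensor& b,
               int axis, framework::Tensor* out) {
  std::vector<framework::Tensor> inputs = {a, b};
  math::ConcatFunctor<platform::CPUDeviceContext, float> concat_functor;
  concat_functor(ctx, inputs, axis, out);
}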
