@@ -12,23 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

+ #include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

- // TODO(zcd): This can be replaced by tensor,
- // if that, maybe we should add int8 to VarType::Type.
- // Or replaced by tensorArray.
- static constexpr int MaxSize = 8;
- template <typename T>
- struct CUDADeviceArray {
-   T data[MaxSize];
-   int size;
- };
-
template <typename T>
__device__ T upper_bound(const T* first, T count, T val) {
  const T* orig = first;
@@ -49,25 +40,24 @@ __device__ T upper_bound(const T* first, T count, T val) {
}

template <typename T>
- __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
-                              const CUDADeviceArray<int> input_cols,
+ __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
                             const int output_rows, const int output_cols,
                             T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-   int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1;
+   int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1;

-   int curr_offset = input_cols.data[segment];
+   int curr_offset = input_cols[segment];
  int curr_segment = segment;
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    T curr_col_offset;
-     while ((curr_col_offset = input_cols.data[curr_segment + 1]) <= tid_x) {
+     while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
-     const T* input_ptr = inputs.data[curr_segment];
+     T* input_ptr = inputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
@@ -76,41 +66,41 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
}

template <typename T>
- __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
-                              const int input_col, const int output_rows,
-                              const int output_cols, T* output) {
+ __global__ void KernelConcat(T** inputs, const int input_col,
+                              const int output_rows, const int output_cols,
+                              T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  float inv_input_col = 1.0 / input_col;
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x * inv_input_col;
    int in_offset = tid_x - split * input_col;
-     const T* input_ptr = inputs.data[split];
+     T* input_ptr = inputs[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-     for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
+     for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * input_col + in_offset];
+     }
  }
}

template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
-                                  const int input_col,
-                                  CUDADeviceArray<int> output_cols,
-                                  CUDADeviceArray<T*> outputs) {
+                                  const int input_col, const int* output_cols,
+                                  int col_size, T** outputs) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-   int segment = upper_bound<int>(output_cols.data, output_cols.size, tid_x) - 1;
-   int curr_offset = output_cols.data[segment];
+   int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
+   int curr_offset = output_cols[segment];
  int curr_segment = segment;
  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
    T curr_col_offset;
-     while ((curr_col_offset = output_cols.data[curr_segment + 1]) <= tid_x) {
+     while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
-     T* output_ptr = outputs.data[curr_segment];
+     T* output_ptr = outputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
      output_ptr[tid_y * segment_width + local_col] =
@@ -121,13 +111,13 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
                                 const int input_col, const int output_cols,
-                                  CUDADeviceArray<T*> outputs) {
+                                  T** outputs) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  float inv_input_col = 1.0 / input_col;
  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x * inv_input_col;
    int in_offset = tid_x - split * input_col;
-     T* output_ptr = outputs.data[split];
+     T* output_ptr = outputs[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
      output_ptr[tid_y * output_cols + in_offset] =
@@ -136,46 +126,45 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
}

/*
-  * All tensors' dimension should be the same.
+  * All tensors' dimension should be the same and the values of
+  * each dimension are the same, except the axis dimension.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, const int axis,
                  framework::Tensor* output) {
-     // assume the the max size of input is less than 8 and see the performance
-     // save origin dim
+     // TODO(zcd): Add input data validity checking
    int num = input.size();
-     PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
-                       MaxSize);
-     // get the matrix size
    int rows = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      rows *= dim_0[i];
    }
    int cols = input[0].numel() / rows;
    int out_rows = rows, out_cols = 0;
-     bool sameShape = true;

-     CUDADeviceArray<const T*> inputs_data;
-     CUDADeviceArray<int> inputs_cols;
-     inputs_data.size = num;
-     inputs_cols.size = num + 1;
-     inputs_cols.data[0] = 0;
-     // reshape to matrix
-     // check input shape is valid
+     paddle::framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
+     paddle::framework::Vector<int> inputs_cols(num + 1);
+     inputs_cols[0] = 0;
+     T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
+
+     bool sameShape = true;
    for (int i = 0; i < num; ++i) {
      int t_cols = input[i].numel() / rows;
      if (sameShape) {
        if (t_cols != cols) sameShape = false;
      }
      out_cols += t_cols;
-       inputs_cols.data[i + 1] = out_cols;
-       inputs_data.data[i] = input[i].data<T>();
+       inputs_cols[i + 1] = out_cols;
+       inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
    }

+     T** ins_gpu =
+         reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
+     const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());
+
    // computation
    // set the thread block and grid according to CurrentDeviceId
    const int kThreadsPerBlock = 1024;
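Note on the hunk above: the fixed-size, pass-by-value CUDADeviceArray (capped at MaxSize = 8 inputs) is replaced by paddle::framework::Vector. The pointer table is filled on the host through inputs_ptr, and CUDAMutableData/CUDAData then supply device-side copies that the kernels receive as plain T** / const int* arguments, so the number of inputs is no longer capped. Because the table lives in a Vector<int16_t>, it is sized as num * sizeof(T*) / 2 elements (sizeof(int16_t) is 2) and reinterpret_cast on both sides. The standalone sketch below shows the same pack-and-copy idea with raw CUDA calls and hypothetical names; it is an illustration of the mechanism, not code from the patch, and error handling is omitted.

#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

// Hypothetical helper (illustration only): stage an array of raw pointers in
// an int16_t buffer, as the patch does with framework::Vector<int16_t>, and
// mirror it on the GPU so a kernel can take the table as a plain T** argument.
template <typename T>
T** CopyPointerTableToDevice(const std::vector<T*>& host_ptrs,
                             std::vector<int16_t>* staging) {
  // sizeof(T*) is a multiple of sizeof(int16_t) == 2, so
  // num * sizeof(T*) / 2 int16_t elements hold exactly num pointers.
  staging->resize(host_ptrs.size() * sizeof(T*) / sizeof(int16_t));
  T** host_table = reinterpret_cast<T**>(staging->data());
  for (size_t i = 0; i < host_ptrs.size(); ++i) host_table[i] = host_ptrs[i];

  void* dev_table = nullptr;
  cudaMalloc(&dev_table, host_ptrs.size() * sizeof(T*));
  cudaMemcpy(dev_table, host_table, host_ptrs.size() * sizeof(T*),
             cudaMemcpyHostToDevice);
  return reinterpret_cast<T**>(dev_table);
}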
@@ -198,27 +187,27 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {

    if (sameShape) {
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-           inputs_data, cols, out_rows, out_cols, output->data<T>());
+           ins_gpu, cols, out_rows, out_cols, output->data<T>());
    } else {
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-           inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
+           ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
+           out_cols, output->data<T>());
    }
  }
};

+ /*
+  * All tensors' dimension should be the same and the values of
+  * each dimension are the same, except the axis dimension.
+  */
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const int axis,
                  std::vector<framework::Tensor>& outputs) {
-     // assume the the max size of input is less than 8 and see the performance
-     // save origin dim
+     // TODO(zcd): Add input data validity checking
    int num = outputs.size();
-     PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
-                       MaxSize);
-
-     // get the matrix size
    int input_row = 1;
    auto dim_0 = outputs[0].dims();
    for (int i = 0; i < axis; ++i) {
@@ -229,24 +218,27 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
    int input_col = 0;
    bool sameShape = true;

-     CUDADeviceArray<T*> outputs_data;
-     CUDADeviceArray<int> outputs_cols;
-     outputs_data.size = num;
-     outputs_cols.size = num + 1;
-     outputs_cols.data[0] = 0;
+     paddle::framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
+     paddle::framework::Vector<int> outputs_cols(num + 1);
+     outputs_cols[0] = 0;
+     T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());

    for (int i = 0; i < num; ++i) {
      int t_col = outputs[i].numel() / input_row;
      if (sameShape) {
        if (t_col != output_col_0) sameShape = false;
      }
      input_col += t_col;
-       outputs_cols.data[i + 1] = input_col;
-       outputs_data.data[i] = outputs[i].data<T>();
+       outputs_cols[i + 1] = input_col;
+       outputs_ptr[i] = outputs[i].data<T>();
    }

+     T** outs_gpu =
+         reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
+     const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());
+
    // computation
-     const int kThreadsPerBlock = 256;
+     const int kThreadsPerBlock = 1024;
    int block_cols = std::min(input_col, kThreadsPerBlock);
    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
    dim3 block_size = dim3(block_cols, block_rows, 1);
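The launch configuration above mirrors ConcatFunctor, with kThreadsPerBlock raised from 256 to 1024: each block holds at most 1024 threads, split between matrix columns and rows. A small worked example of what these two lines produce (sketch only, names local to the example):

#include <algorithm>
#include <cstdio>
#include <initializer_list>

// Illustration of the block-shape rule used above: narrow matrices get taller
// blocks, wide ones get flat 1024 x 1 blocks.
int main() {
  const int kThreadsPerBlock = 1024;
  for (int input_col : {300, 2000}) {
    int block_cols = std::min(input_col, kThreadsPerBlock);
    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
    // input_col = 300  -> block 300 x 3 (900 threads)
    // input_col = 2000 -> block 1024 x 1
    std::printf("input_col=%d -> block %d x %d\n", input_col, block_cols,
                block_rows);
  }
  return 0;
}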
@@ -257,10 +249,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {

    if (sameShape) {
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-           input.data<T>(), input_row, input_col, output_col_0, outputs_data);
+           input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
    } else {
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-           input.data<T>(), input_row, input_col, outputs_cols, outputs_data);
+           input.data<T>(), input_row, input_col, outs_col_gpu,
+           static_cast<int>(outputs_cols.size()), outs_gpu);
    }
  }
};
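Both variable-width kernels in this file (KernelConcat taking input_cols/col_size and KernelConcatGrad taking output_cols/col_size) walk the column space the same way: the offsets array is an exclusive prefix sum of per-tensor column counts, upper_bound(...) - 1 picks the segment that owns a thread's first column, and the inner while loop advances the segment as tid_x strides across the row. A host-side sketch of that mapping, with hypothetical names and a linear scan in place of the binary search, purely for illustration:

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical host-side illustration (not part of the patch): map a global
// column index to (segment, local column) given the exclusive prefix sum of
// per-tensor column widths, which is what the kernels compute per thread.
struct SegmentPos {
  int segment;    // which input/output tensor owns the column
  int local_col;  // column index inside that tensor
};

SegmentPos LocateColumn(const std::vector<int>& col_offsets, int col) {
  assert(col_offsets.size() >= 2 && col >= 0 && col < col_offsets.back());
  int segment = 0;
  // Linear scan for clarity; the kernels use upper_bound for the first column
  // and then advance the segment incrementally as tid_x strides forward.
  while (col_offsets[segment + 1] <= col) ++segment;
  return {segment, col - col_offsets[segment]};
}

int main() {
  // Three tensors with 3, 2 and 4 columns -> offsets {0, 3, 5, 9}.
  std::vector<int> offsets = {0, 3, 5, 9};
  SegmentPos p = LocateColumn(offsets, 4);
  std::printf("segment=%d local_col=%d\n", p.segment, p.local_col);  // 1, 1
  return 0;
}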
0 commit comments