
Commit 00e596e

get max threads of GPU
1 parent 60e7ee0 commit 00e596e

7 files changed (+320 additions, -55 deletions)

paddle/fluid/operators/concat_op.h

Lines changed: 10 additions & 9 deletions
@@ -32,6 +32,7 @@ class ConcatKernel : public framework::OpKernel<T> {
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
+
     std::vector<framework::Tensor> inputs(ins.size());
     for (size_t j = 0; j < ins.size(); ++j) {
       inputs[j] = *ins[j];
@@ -49,17 +50,17 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
-    size_t input_offset = 0;
-    auto in_stride = framework::stride_numel(in->dims());
 
-    for (auto& out : outs) {
-      out->mutable_data<T>(ctx.GetPlace());
-      auto out_stride = framework::stride_numel(out->dims());
-      StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
-                                  out_stride, in->data<T>() + input_offset,
-                                  in_stride, out_stride[axis]);
-      input_offset += out_stride[axis];
+    std::vector<framework::Tensor> outputs(outs.size());
+    for (size_t j = 0; j < outs.size(); ++j) {
+      outs[j]->mutable_data<T>(ctx.GetPlace());
+      outputs[j] = *outs[j];
     }
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
+        concat_grad_functor;
+    concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
   }
 };
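With this change both kernels delegate the actual copying to math::ConcatFunctor and the new math::ConcatGradFunctor. Their declarations are not part of the hunks shown here (they would live in one of the other changed files, presumably paddle/fluid/operators/math/concat.h); the following is only a sketch inferred from the call sites above and the specializations below, not the actual header:

template <typename DeviceContext, typename T>
class ConcatFunctor {
 public:
  // Concatenate `input` along `axis` into the pre-allocated `output`.
  void operator()(const DeviceContext& context,
                  const std::vector<framework::Tensor>& input, const int axis,
                  framework::Tensor* output);
};

template <typename DeviceContext, typename T>
class ConcatGradFunctor {
 public:
  // Split `input` along `axis` back into the pre-allocated `outputs`.
  void operator()(const DeviceContext& context, const framework::Tensor& input,
                  const int axis, std::vector<framework::Tensor>& outputs);
};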

paddle/fluid/operators/math/concat.cc

Lines changed: 54 additions & 21 deletions
@@ -25,57 +25,85 @@ template <typename T>
 class ConcatFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, const int axis,
                   framework::Tensor* output) {
     // assume the the max size of input is less than 8 and see the performance
     // save origin dim
     int num = input.size();
     std::vector<paddle::framework::DDim> origin_dim(num);
-    // for (int j = 0; j < num; ++j) {
-    //   origin_dim[j] = input[j].dims();
-    // }
-    auto out_dim = output->dims();
 
     // get the matrix size
     int rows = 1;
     auto dim_0 = input[0].dims();
     for (int i = 0; i < axis; ++i) {
       rows *= dim_0[i];
     }
-    int cols = input[0].numel() / rows;
     int out_rows = rows, out_cols = 0;
-    bool sameShape = true;
 
-    // reshape to matrix
+    // get input's cols
+    std::vector<int64_t> input_cols(input.size());
     for (int i = 0; i < num; ++i) {
       int t_cols = input[i].numel() / rows;
-      if (sameShape) {
-        if (t_cols != cols) sameShape = false;
-      }
       out_cols += t_cols;
-      input[i].Resize({rows, t_cols});
+      input_cols[i] = t_cols;
     }
-    output->Resize({out_rows, out_cols});
     auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
+
     // computation
-    for (int k = 0; k < rows; ++k) {
-      // offset k * out_cols
+    for (int k = 0; k < out_rows; ++k) {
       T* dst_ptr = output->data<T>() + k * out_cols;
       int col_idx = 0;
       for (int j = 0; j < num; ++j) {
-        int col_len = input[j].dims()[1];
+        int col_len = input_cols[j];
         const T* src_prt = input[j].data<T>() + k * col_len;
         memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt,
                      sizeof(T) * col_len);
         col_idx += col_len;
       }
     }
+  }
+};
+
+template <typename T>
+class ConcatGradFunctor<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::Tensor& input, const int axis,
+                  std::vector<framework::Tensor>& outputs) {
+    // assume the the max size of input is less than 8 and see the performance
+    // save origin dim
+    int num = outputs.size();
+    std::vector<paddle::framework::DDim> origin_dim(num);
 
-    // recover origin dim
-    // for (int j = 0; j < num; ++j) {
-    //   input[j]->Resize(origin_dim[j]);
-    // }
-    output->Resize(out_dim);
+    // get the matrix size
+    int input_rows = 1;
+    auto dim_0 = outputs[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      input_rows *= dim_0[i];
+    }
+    int input_cols = 0;
+
+    // get outputs' cols
+    std::vector<int64_t> output_cols(outputs.size());
+    for (int i = 0; i < num; ++i) {
+      int t_cols = outputs[i].numel() / input_rows;
+      input_cols += t_cols;
+      output_cols[i] = t_cols;
+    }
+    auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
+
+    // computation
+    for (int k = 0; k < input_rows; ++k) {
+      const T* src_ptr = input.data<T>() + k * input_cols;
+      int col_idx = 0;
+      for (int j = 0; j < num; ++j) {
+        int col_len = output_cols[j];
+        T* dst_ptr = outputs[j].data<T>() + k * col_len;
+        memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
+                     sizeof(T) * col_len);
+        col_idx += col_len;
+      }
+    }
   }
 };
 
@@ -84,6 +112,11 @@ template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
 template class ConcatFunctor<platform::CPUDeviceContext, float>;
 template class ConcatFunctor<platform::CPUDeviceContext, double>;
 
+template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
+template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
+template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
+template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
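Both CPU functors reduce the N-D concatenation to a 2-D copy problem: every dimension before `axis` is folded into `rows`, and each input contributes `numel() / rows` columns. For example, concatenating tensors of shape [2, 3] and [2, 5] along axis 1 gives rows = 2, per-input cols of 3 and 5, and an output of shape [2, 8]. A minimal standalone sketch of the same arithmetic (plain C++ with made-up shapes, no Paddle types):

#include <cstring>
#include <iostream>
#include <numeric>
#include <vector>

// Concatenate two row-major buffers along the last ("cols") dimension,
// using the same rows/cols flattening as the CPU ConcatFunctor above.
int main() {
  // Inputs of shape [2, 3] and [2, 5] concatenated along axis 1:
  // rows = 2, cols are 3 and 5, so the output is [2, 8].
  const int rows = 2;
  const std::vector<int> cols = {3, 5};
  std::vector<float> in0(rows * cols[0], 1.f);  // filled with 1s
  std::vector<float> in1(rows * cols[1], 2.f);  // filled with 2s
  const std::vector<const float*> inputs = {in0.data(), in1.data()};

  const int out_cols = std::accumulate(cols.begin(), cols.end(), 0);
  std::vector<float> out(rows * out_cols, 0.f);

  // Row by row, copy each input's slice next to the previous one,
  // mirroring the memory::Copy loop in the CPU functor.
  for (int k = 0; k < rows; ++k) {
    float* dst = out.data() + k * out_cols;
    int col_idx = 0;
    for (size_t j = 0; j < inputs.size(); ++j) {
      const float* src = inputs[j] + k * cols[j];
      std::memcpy(dst + col_idx, src, sizeof(float) * cols[j]);
      col_idx += cols[j];
    }
  }

  for (int k = 0; k < rows; ++k) {
    for (int c = 0; c < out_cols; ++c) std::cout << out[k * out_cols + c] << ' ';
    std::cout << '\n';  // prints: 1 1 1 2 2 2 2 2 (twice)
  }
  return 0;
}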

paddle/fluid/operators/math/concat.cu

Lines changed: 148 additions & 22 deletions
@@ -22,7 +22,7 @@ namespace math {
 // TODO(zcd): This can be replaced by tensor,
 // if that, maybe we should add int8 to VarType::Type.
 // Or replaced by tensorArray.
-static constexpr int MaxSize = 32;
+static constexpr int MaxSize = 8;
 template <typename T>
 struct CUDADeviceArray {
   T data[MaxSize];
@@ -54,7 +54,6 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
                              const int output_rows, const int output_cols,
                              T* output) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
   int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1;
 
   int curr_offset = input_cols.data[segment];
@@ -69,31 +68,87 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
     int local_col = tid_x - curr_offset;
     int segment_width = curr_col_offset - curr_offset;
     const T* input_ptr = inputs.data[curr_segment];
-
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
       output[tid_y * output_cols + tid_x] =
           input_ptr[tid_y * segment_width + local_col];
   }
 }
 
+template <typename T>
+__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
+                             const int input_col, const int output_rows,
+                             const int output_cols, T* output) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  float inv_input_col = 1.0 / input_col;
+  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x * inv_input_col;
+    int in_offset = tid_x - split * input_col;
+    const T* input_ptr = inputs.data[split];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
+      output[tid_y * output_cols + tid_x] =
+          input_ptr[tid_y * input_col + in_offset];
+  }
+}
+
+template <typename T>
+__global__ void KernelConcatGrad(const T* input, const int input_row,
+                                 const int input_col,
+                                 CUDADeviceArray<int> output_cols,
+                                 CUDADeviceArray<T*> outputs) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int segment = upper_bound<int>(output_cols.data, output_cols.size, tid_x) - 1;
+  int curr_offset = output_cols.data[segment];
+  int curr_segment = segment;
+  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+    T curr_col_offset;
+    while ((curr_col_offset = output_cols.data[curr_segment + 1]) <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+    T* output_ptr = outputs.data[curr_segment];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * segment_width + local_col] =
+          input[tid_y * input_col + tid_x];
+  }
+}
+
+template <typename T>
+__global__ void KernelConcatGrad(const T* input, const int input_row,
+                                 const int input_col, const int output_cols,
+                                 CUDADeviceArray<T*> outputs) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  float inv_input_col = 1.0 / input_col;
+  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x * inv_input_col;
+    int in_offset = tid_x - split * input_col;
+    T* output_ptr = outputs.data[split];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * output_cols + in_offset] =
+          input[tid_y * input_col + tid_x];
+  }
+}
+
 /*
  * All tensors' dimension should be the same.
  */
 template <typename T>
 class ConcatFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, const int axis,
                   framework::Tensor* output) {
     // assume the the max size of input is less than 8 and see the performance
     // save origin dim
     int num = input.size();
-    // std::vector<paddle::framework::DDim> origin_dim(num);
-    // for (int j = 0; j < num; ++j) {
-    //   origin_dim[j] = input[j].dims();
-    // }
-    auto out_dim = output->dims();
-
+    PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
+                      MaxSize);
     // get the matrix size
     int rows = 1;
     auto dim_0 = input[0].dims();
@@ -117,30 +172,96 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
         if (t_cols != cols) sameShape = false;
       }
       out_cols += t_cols;
-      input[i].Resize({rows, t_cols});
       inputs_cols.data[i + 1] = out_cols;
       inputs_data.data[i] = input[i].data<T>();
     }
-    output->Resize({out_rows, out_cols});
 
     // computation
-    const int kThreadsPerBlock = 256;
+    // set the thread block and grid according to CurrentDeviceId
+    const int kThreadsPerBlock = 1024;
     int block_cols = std::min(out_cols, kThreadsPerBlock);
     int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
     dim3 block_size = dim3(block_cols, block_rows, 1);
 
-    int grid_cols = (out_cols + block_cols - 1) / block_cols;
-    int grid_rows = (out_rows + block_rows - 1) / block_rows;
+    int dev_id = paddle::platform::GetCurrentDeviceId();
+    int multi_process = paddle::platform::GetCUDAMultiProcessors(dev_id);
+    int max_threads_per_mp =
+        paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(dev_id);
+    int max_threads = multi_process * max_threads_per_mp;
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    int grid_cols =
+        std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
+    int grid_rows =
+        std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
 
-    KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-        inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
+    if (sameShape) {
+      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
+          inputs_data, cols, out_rows, out_cols, output->data<T>());
+    } else {
+      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
+          inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
+    }
+  }
+};
+
+template <typename T>
+class ConcatGradFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input, const int axis,
+                  std::vector<framework::Tensor>& outputs) {
+    // assume the the max size of input is less than 8 and see the performance
+    // save origin dim
+    int num = outputs.size();
+    PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
+                      MaxSize);
+
+    // get the matrix size
+    int input_row = 1;
+    auto dim_0 = outputs[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      input_row *= dim_0[i];
+    }
+
+    int output_col_0 = outputs[0].numel() / input_row;
+    int input_col = 0;
+    bool sameShape = true;
+
+    CUDADeviceArray<T*> outputs_data;
+    CUDADeviceArray<int> outputs_cols;
+    outputs_data.size = num;
+    outputs_cols.size = num + 1;
+    outputs_cols.data[0] = 0;
 
-    // recover origin dim
-    // for (int j = 0; j < num; ++j) {
-    //   input[j].Resize(origin_dim[j]);
-    // }
-    output->Resize(out_dim);
+    for (int i = 0; i < num; ++i) {
+      int t_col = outputs[i].numel() / input_row;
+      if (sameShape) {
+        if (t_col != output_col_0) sameShape = false;
+      }
+      input_col += t_col;
+      outputs_cols.data[i + 1] = input_col;
+      outputs_data.data[i] = outputs[i].data<T>();
+    }
+
+    // computation
+    const int kThreadsPerBlock = 256;
+    int block_cols = std::min(input_col, kThreadsPerBlock);
+    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+    dim3 block_size = dim3(block_cols, block_rows, 1);
+
+    int grid_cols = (input_col + block_cols - 1) / block_cols;
+    int grid_rows = (input_row + block_rows - 1) / block_rows;
+    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
+
+    if (sameShape) {
+      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
+          input.data<T>(), input_row, input_col, output_col_0, outputs_data);
+    } else {
+      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
+          input.data<T>(), input_row, input_col, outputs_cols, outputs_data);
+    }
   }
 };
 
@@ -149,6 +270,11 @@ template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
 template class ConcatFunctor<platform::CUDADeviceContext, float>;
 template class ConcatFunctor<platform::CUDADeviceContext, double>;
 
+template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
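The launch configuration is the point of the commit message: instead of sizing the grid purely from the output shape, ConcatFunctor<CUDADeviceContext, T> now caps the number of blocks at the device's total thread capacity (SM count times max resident threads per SM) and lets the kernels cover the remaining work with grid-stride loops (tid_x += blockDim.x * gridDim.x). The paddle::platform helpers used above presumably wrap the CUDA runtime's device queries; below is a standalone sketch of the same grid-sizing arithmetic, where the example out_rows/out_cols values are made up for illustration:

#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int dev_id = 0;
  cudaGetDevice(&dev_id);

  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, dev_id);

  // Total number of threads the device can keep resident at once.
  int max_threads = prop.multiProcessorCount * prop.maxThreadsPerMultiProcessor;

  // Same block/grid arithmetic as ConcatFunctor<CUDADeviceContext, T> above,
  // with example matrix sizes (out_rows x out_cols).
  const int kThreadsPerBlock = 1024;
  int out_rows = 4096, out_cols = 300;

  int block_cols = std::min(out_cols, kThreadsPerBlock);
  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);

  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
  int grid_cols = std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows =
      std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));

  printf("block = (%d, %d), grid = (%d, %d), max_blocks = %d\n", block_cols,
         block_rows, grid_cols, grid_rows, max_blocks);
  return 0;
}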
