Commit 131ec27

fix bug for big number; float->double and code refine
1 parent 82bd82c commit 131ec27

3 files changed: 36 additions & 17 deletions

paddle/fluid/operators/math/concat.cu

Lines changed: 24 additions & 17 deletions
@@ -70,7 +70,7 @@ __global__ void KernelConcat(T** inputs, const int input_col,
                              const int output_rows, const int output_cols,
                              T* output) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  float inv_input_col = 1.0 / input_col;
+  double inv_input_col = 1.0 / input_col;
   for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
     int split = tid_x * inv_input_col;
     int in_offset = tid_x - split * input_col;
@@ -113,7 +113,7 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
                                  const int input_col, const int output_cols,
                                  T** outputs) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  float inv_input_col = 1.0 / input_col;
+  double inv_input_col = 1.0 / input_col;
   for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
     int split = tid_x * inv_input_col;
     int in_offset = tid_x - split * input_col;
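
Replacing the float reciprocal with a double one is the "big number" fix: with only 24 mantissa bits, tid_x * (1.0f / input_col) can round across an integer boundary once tid_x is large, so the computed split row disagrees with the exact integer division. The following minimal host-side check is an illustration, not code from this commit; on typical IEEE-754 hardware it reports a float mismatch well inside the scanned range and no double mismatch.

#include <cstdio>

int main() {
  const int input_col = 10;  // arbitrary column width for the demonstration
  const int limit = 100000000;
  int first_float_miss = -1;
  int first_double_miss = -1;
  for (int tid_x = 0; tid_x < limit; ++tid_x) {
    const int exact = tid_x / input_col;
    // Same expression the kernels use for `split`, with float vs. double.
    if (first_float_miss < 0 &&
        static_cast<int>(tid_x * (1.0f / input_col)) != exact) {
      first_float_miss = tid_x;
    }
    if (first_double_miss < 0 &&
        static_cast<int>(tid_x * (1.0 / input_col)) != exact) {
      first_double_miss = tid_x;
    }
  }
  std::printf("first float mismatch:  %d\n", first_float_miss);
  std::printf("first double mismatch: %d\n", first_double_miss);  // stays -1
  return 0;
}
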
@@ -145,8 +145,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     int cols = input[0].numel() / rows;
     int out_rows = rows, out_cols = 0;
 
-    paddle::framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
-    paddle::framework::Vector<int> inputs_cols(num + 1);
+    framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
+    framework::Vector<int> inputs_cols(num + 1);
     inputs_cols[0] = 0;
     T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
 
@@ -168,15 +168,14 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     // computation
     // set the thread block and grid according to CurrentDeviceId
     const int kThreadsPerBlock = 1024;
-    int block_cols = std::min(out_cols, kThreadsPerBlock);
-    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+    int block_cols = kThreadsPerBlock;
+    if (out_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((out_cols + 31) >> 5) << 5;
+    }
+    int block_rows = kThreadsPerBlock / block_cols;
     dim3 block_size = dim3(block_cols, block_rows, 1);
 
-    int dev_id = paddle::platform::GetCurrentDeviceId();
-    int multi_process = paddle::platform::GetCUDAMultiProcessors(dev_id);
-    int max_threads_per_mp =
-        paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(dev_id);
-    int max_threads = multi_process * max_threads_per_mp;
+    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
 
     int grid_cols =
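
The new block shape rounds the column count up to a multiple of 32 (the warp size) instead of using it directly, so narrow inputs still launch whole warps, and block_rows = kThreadsPerBlock / block_cols keeps the block at or under 1024 threads. A tiny host-side illustration of the ((n + 31) >> 5) << 5 rounding (not code from this commit):

#include <cstdio>
#include <initializer_list>

// Round n up to the next multiple of 32, the same bit trick used for
// block_cols above: add 31, drop the low 5 bits, shift back.
int RoundUpToWarp(int n) { return ((n + 31) >> 5) << 5; }

int main() {
  // Sample widths: 1 -> 32, 31 -> 32, 33 -> 64, 100 -> 128, 1000 -> 1024.
  for (int n : {1, 31, 33, 100, 1000}) {
    std::printf("%4d -> %4d\n", n, RoundUpToWarp(n));
  }
  return 0;
}
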
@@ -218,8 +217,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
     int input_col = 0;
     bool sameShape = true;
 
-    paddle::framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
-    paddle::framework::Vector<int> outputs_cols(num + 1);
+    framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
+    framework::Vector<int> outputs_cols(num + 1);
     outputs_cols[0] = 0;
     T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
 
@@ -239,12 +238,20 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 
     // computation
     const int kThreadsPerBlock = 1024;
-    int block_cols = std::min(input_col, kThreadsPerBlock);
-    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+    int block_cols = kThreadsPerBlock;
+    if (input_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((input_col + 31) >> 5) << 5;
+    }
+    int block_rows = kThreadsPerBlock / block_cols;
     dim3 block_size = dim3(block_cols, block_rows, 1);
 
-    int grid_cols = (input_col + block_cols - 1) / block_cols;
-    int grid_rows = (input_row + block_rows - 1) / block_rows;
+    int max_threads = context.GetMaxPhysicalThreadCount();
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    int grid_cols =
+        std::min((input_col + block_cols - 1) / block_cols, max_blocks);
+    int grid_rows =
+        std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
 
     if (sameShape) {
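
Both functors now cap the grid at roughly the device's physical thread budget and rely on the kernels' grid-stride loops (tid_x += blockDim.x * gridDim.x) to cover whatever a smaller grid does not touch on its first pass. The sketch below is a simplified 1-D illustration of that launch pattern with a stand-in kernel; the names are illustrative, and max_threads is passed in where the functors call context.GetMaxPhysicalThreadCount().

#include <cuda_runtime.h>
#include <algorithm>

// Stand-in kernel: each thread strides over the columns, so a grid capped
// below the ceiling-division size still visits every element, just with
// more iterations per thread.
__global__ void ScaleColumns(float* data, int cols) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < cols; tid_x += blockDim.x * gridDim.x) {
    data[tid_x] *= 2.0f;
  }
}

// max_threads would be context.GetMaxPhysicalThreadCount() in the functors;
// it is a plain parameter here to keep the sketch self-contained.
void LaunchScale(float* data, int cols, int max_threads, cudaStream_t stream) {
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (cols < kThreadsPerBlock) {  // align the block to whole 32-thread warps
    block_cols = ((cols + 31) >> 5) << 5;
  }
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
  int grid_cols = std::min((cols + block_cols - 1) / block_cols, max_blocks);
  ScaleColumns<<<grid_cols, block_cols, 0, stream>>>(data, cols);
}

Capping the grid this way keeps the number of launched blocks close to what the SMs can actually hold resident, so very large inputs do not pay for scheduling far more blocks than can run concurrently.
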

paddle/fluid/platform/device_context.cc

Lines changed: 6 additions & 0 deletions
@@ -121,6 +121,8 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
 
 CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
   SetDeviceId(place_.device);
+  multi_process = GetCUDAMultiProcessors(place_.device);
+  max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
   PADDLE_ENFORCE(cudaStreamCreate(&stream_));
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
@@ -154,6 +156,10 @@ void CUDADeviceContext::Wait() const {
   PADDLE_ENFORCE(cudaGetLastError());
 }
 
+int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
+  return multi_process * max_threads_per_mp;
+}
+
 Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
   return eigen_device_.get();
 }
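
The two fields cached in the constructor come from plain CUDA device attributes, so GetMaxPhysicalThreadCount() becomes a cheap multiplication instead of a per-launch device query. As a point of reference, here is a standalone sketch computing the same quantity directly with the CUDA runtime API (it bypasses the repository's platform helpers and is only an illustration):

#include <cuda_runtime.h>
#include <cstdio>

// Maximum number of resident ("physical") threads on a device:
// SM count * max resident threads per SM.
int MaxPhysicalThreadCount(int device_id) {
  int sm_count = 0;
  int threads_per_sm = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id);
  cudaDeviceGetAttribute(&threads_per_sm,
                         cudaDevAttrMaxThreadsPerMultiProcessor, device_id);
  return sm_count * threads_per_sm;
}

int main() {
  std::printf("max physical threads on device 0: %d\n",
              MaxPhysicalThreadCount(0));
  return 0;
}
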

paddle/fluid/platform/device_context.h

Lines changed: 6 additions & 0 deletions
@@ -79,6 +79,9 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief Return place in the device context. */
   Place GetPlace() const override;
 
+  /*! \brief Return the max physical thread count in the device context */
+  int GetMaxPhysicalThreadCount() const;
+
   /*! \brief Return eigen device in the device context. */
   Eigen::GpuDevice* eigen_device() const;
 
@@ -100,6 +103,9 @@ class CUDADeviceContext : public DeviceContext {
   cudaStream_t stream_;
   cudnnHandle_t cudnn_handle_;
   cublasHandle_t cublas_handle_;
+
+  int multi_process;
+  int max_threads_per_mp;
 };
 
 template <>
