Skip to content

Commit 43a3af8

Browse files
author
chengduo
authored
refine sgd_op (#13626)
test=develop
1 parent adae0a3 commit 43a3af8

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

paddle/fluid/operators/sgd_op.cu

Lines changed: 21 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#define EIGEN_USE_GPU
15+
#include <algorithm>
1616
#include "paddle/fluid/operators/sgd_op.h"
1717
#include "paddle/fluid/platform/cuda_primitives.h"
1818

@@ -33,22 +33,21 @@ __global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
3333
}
3434
}
3535

36-
template <typename T, int block_size>
36+
template <typename T>
3737
__global__ void SparseSGDFunctorKernel(const T* selected_rows,
3838
const int64_t* rows,
3939
const T* learning_rate, T* tensor_out,
40-
int64_t row_numel) {
41-
const int ty = blockIdx.y;
42-
int tid = threadIdx.x;
43-
44-
selected_rows += ty * row_numel;
45-
tensor_out += rows[ty] * row_numel;
46-
47-
for (int index = tid; index < row_numel; index += block_size) {
48-
// Since index in rows of SelectedRows can be duplicate, we have to use
49-
// Atomic Operation to avoid concurrent write error.
50-
paddle::platform::CudaAtomicAdd(
51-
tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]);
40+
int64_t row_numel, int64_t limit) {
41+
for (int64_t i = blockIdx.x; i < limit; i += gridDim.x) {
42+
const T* selected_rows_ptr = selected_rows + i * row_numel;
43+
T* tensor_out_ptr = tensor_out + rows[i] * row_numel;
44+
for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) {
45+
// Since index in rows of SelectedRows can be duplicate, we have to use
46+
// Atomic Operation to avoid concurrent write error.
47+
paddle::platform::CudaAtomicAdd(
48+
tensor_out_ptr + index,
49+
-1.0 * learning_rate[0] * selected_rows_ptr[index]);
50+
}
5251
}
5352
}
5453
} // namespace
@@ -97,13 +96,15 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
9796
auto* in_data = in_value.data<T>();
9897
auto* out_data = param_out->data<T>();
9998

100-
const int block_size = 256;
101-
dim3 threads(block_size, 1);
102-
dim3 grid(1, in_rows.size());
103-
SparseSGDFunctorKernel<
104-
T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
99+
const int kThreadsPerBlock = 256;
100+
int thread_x = kThreadsPerBlock;
101+
int max_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
102+
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
103+
104+
SparseSGDFunctorKernel<<<max_blocks, thread_x, 0,
105+
ctx.cuda_device_context().stream()>>>(
105106
in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate->data<T>(),
106-
out_data, in_row_numel);
107+
out_data, in_row_numel, in_rows.size());
107108

108109
} else {
109110
PADDLE_THROW("Unsupported Variable Type of Grad");

0 commit comments

Comments (0)