Skip to content

Commit a270fdf

Browse files
author
chengduo
authored
Fix SelectedRowsAdd bug (#14309)
* fix selected_rows bug test=develop * refine cos_sim test=develop
1 parent 1001f8e commit a270fdf

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

paddle/fluid/operators/math/cos_sim_functor.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
5151
T* dy) const {
5252
const int block_size = 512;
5353
dim3 threads(block_size, 1);
54-
dim3 grid(1, (rows + block_size - 1) / block_size);
54+
dim3 grid((rows + block_size - 1) / block_size, 1);
5555
CosSimDyKernel<T><<<grid, threads, 0, ctx.stream()>>>(
5656
x_norm, y_norm, x, y, z, dz, rows, cols, dy);
5757
}

paddle/fluid/operators/math/selected_rows_functor.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ template <typename T, int block_size>
8181
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
8282
const int64_t* rows, T* tensor_out,
8383
int64_t row_numel) {
84-
const int ty = blockIdx.y;
84+
const int ty = blockIdx.x;
8585
int tid = threadIdx.x;
8686

8787
selected_rows += ty * row_numel;
@@ -123,7 +123,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
123123

124124
const int block_size = 256;
125125
dim3 threads(block_size, 1);
126-
dim3 grid(1, in1_rows.size());
126+
dim3 grid(in1_rows.size(), 1);
127127
SelectedRowsAddTensorKernel<
128128
T, block_size><<<grid, threads, 0, context.stream()>>>(
129129
in1_data, in1_rows.CUDAData(context.GetPlace()), out_data,
@@ -188,7 +188,7 @@ __global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
188188
const int64_t* rows,
189189
T* tensor_out,
190190
int64_t row_numel) {
191-
const int ty = blockIdx.y;
191+
const int ty = blockIdx.x;
192192
int tid = threadIdx.x;
193193

194194
selected_rows += ty * row_numel;
@@ -221,7 +221,7 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
221221
auto* in2_data = input2->data<T>();
222222
const int block_size = 256;
223223
dim3 threads(block_size, 1);
224-
dim3 grid(1, in1_rows.size());
224+
dim3 grid(in1_rows.size(), 1);
225225
SelectedRowsAddToTensorKernel<
226226
T, block_size><<<grid, threads, 0, context.stream()>>>(
227227
in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data,
@@ -388,7 +388,7 @@ template <typename T, int block_size>
388388
__global__ void UpdateToTensorKernel(const T* selected_rows,
389389
const int64_t* rows, const ScatterOps& op,
390390
T* tensor_out, int64_t row_numel) {
391-
const int ty = blockIdx.y;
391+
const int ty = blockIdx.x;
392392
int tid = threadIdx.x;
393393

394394
selected_rows += ty * row_numel;
@@ -457,7 +457,7 @@ struct UpdateToTensor<platform::CUDADeviceContext, T> {
457457
auto* in2_data = input2->data<T>();
458458

459459
dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
460-
dim3 grid(1, in1_rows.size());
460+
dim3 grid(in1_rows.size(), 1);
461461
UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
462462
grid, threads, 0, context.stream()>>>(in1_data, in1_rows.cuda_data(),
463463
op, in2_data, in1_row_numel);

0 commit comments

Comments
 (0)