Skip to content

Commit d0b601c

Browse files
author
Markus Kliegl
committed
address PR feedback
1 parent 42dd5da commit d0b601c

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

paddle/operators/conv_shift_op.cu

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
limitations under the License. */
1414

1515
#include "paddle/operators/conv_shift_op.h"
16+
#include "paddle/operators/math/math_function.h"
1617
#include "paddle/platform/cuda_helper.h"
1718

1819
namespace paddle {
@@ -33,9 +34,9 @@ inline int DivUp(int x, int y) { return (x + y - 1) / y; }
3334
// y is fairly small. For large y, it would probably be more efficient
3435
// to also tile across y.
3536
template <typename T>
36-
__global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
37-
int y_width, int y_half_width,
38-
int batch_size) {
37+
__global__ void ConvShiftForward(const T *x, const T *y, int x_width,
38+
int y_width, int y_half_width, int batch_size,
39+
T *out) {
3940
extern __shared__ T mem[];
4041

4142
int tx = threadIdx.x;
@@ -79,8 +80,9 @@ __global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
7980

8081
// Compute x gradient - initial naive implementation with atomic add.
8182
template <typename T>
82-
__global__ void ConvShiftGradX(const T *dout, const T *y, T *dx, int x_width,
83-
int y_width, int y_half_width, int batch_size) {
83+
__global__ void ConvShiftGradX(const T *dout, const T *y, int x_width,
84+
int y_width, int y_half_width, int batch_size,
85+
T *dx) {
8486
int i = blockIdx.x * blockDim.x + threadIdx.x; // x index
8587
int j = blockIdx.y; // y index
8688
int k = blockIdx.z; // batch index
@@ -94,8 +96,8 @@ __global__ void ConvShiftGradX(const T *dout, const T *y, T *dx, int x_width,
9496

9597
// Compute y gradient - initial naive implementation with atomic add.
9698
template <typename T>
97-
__global__ void ConvShiftDy(const T *x, const T *dout, T *dy, int x_width,
98-
int y_width, int y_half_width, int batch_size) {
99+
__global__ void ConvShiftDy(const T *x, const T *dout, int x_width, int y_width,
100+
int y_half_width, int batch_size, T *dy) {
99101
int i = blockIdx.x * blockDim.x + threadIdx.x; // x index
100102
int j = blockIdx.y; // y index
101103
int k = blockIdx.z; // batch index
@@ -133,7 +135,7 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
133135
auto stream = context.cuda_device_context().stream();
134136

135137
ConvShiftForward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
136-
x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
138+
x_data, y_data, x_width, y_width, y_half_width, batch_size, out_data);
137139
}
138140
};
139141

@@ -157,25 +159,26 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
157159
int y_width = Y->dims()[1];
158160
int y_half_width = (y_width - 1) / 2;
159161

160-
auto stream = context.cuda_device_context().stream();
162+
auto &device_ctx = context.cuda_device_context();
163+
math::SetConstant<platform::GPUPlace, T> zero;
161164

162165
const int x_per_block = 256;
163166
int num_x_blocks = DivUp(x_width, x_per_block);
164167
dim3 grid_dim(num_x_blocks, y_width, batch_size);
165168

166169
if (dX) {
167170
T *dx_data = dX->mutable_data<T>(context.GetPlace());
168-
cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream);
169-
ConvShiftGradX<T><<<grid_dim, x_per_block, 0, stream>>>(
170-
dout_data, y_data, dx_data, x_width, y_width, y_half_width,
171-
batch_size);
171+
zero(device_ctx, dX, static_cast<T>(0.0));
172+
ConvShiftGradX<T><<<grid_dim, x_per_block, 0, device_ctx.stream()>>>(
173+
dout_data, y_data, x_width, y_width, y_half_width, batch_size,
174+
dx_data);
172175
}
173176
if (dY) {
174177
T *dy_data = dY->mutable_data<T>(context.GetPlace());
175-
cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream);
176-
ConvShiftDy<T><<<grid_dim, x_per_block, 0, stream>>>(
177-
x_data, dout_data, dy_data, x_width, y_width, y_half_width,
178-
batch_size);
178+
zero(device_ctx, dY, static_cast<T>(0.0));
179+
ConvShiftDy<T><<<grid_dim, x_per_block, 0, device_ctx.stream()>>>(
180+
x_data, dout_data, x_width, y_width, y_half_width, batch_size,
181+
dy_data);
179182
}
180183
}
181184
};

0 commit comments

Comments
 (0)