13
13
limitations under the License. */
14
14
15
15
#include " paddle/operators/conv_shift_op.h"
16
+ #include " paddle/operators/math/math_function.h"
16
17
#include " paddle/platform/cuda_helper.h"
17
18
18
19
namespace paddle {
@@ -33,9 +34,9 @@ inline int DivUp(int x, int y) { return (x + y - 1) / y; }
33
34
// y is fairly small. For large y, it would probably be more efficient
34
35
// to also tile across y.
35
36
template <typename T>
36
- __global__ void ConvShiftForward (const T *x, const T *y, T *out, int x_width,
37
- int y_width, int y_half_width,
38
- int batch_size ) {
37
+ __global__ void ConvShiftForward (const T *x, const T *y, int x_width,
38
+ int y_width, int y_half_width, int batch_size,
39
+ T *out ) {
39
40
extern __shared__ T mem[];
40
41
41
42
int tx = threadIdx .x ;
@@ -79,8 +80,9 @@ __global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
79
80
80
81
// Compute x gradient - initial naive implementation with atomic add.
81
82
template <typename T>
82
- __global__ void ConvShiftGradX (const T *dout, const T *y, T *dx, int x_width,
83
- int y_width, int y_half_width, int batch_size) {
83
+ __global__ void ConvShiftGradX (const T *dout, const T *y, int x_width,
84
+ int y_width, int y_half_width, int batch_size,
85
+ T *dx) {
84
86
int i = blockIdx .x * blockDim .x + threadIdx .x ; // x index
85
87
int j = blockIdx .y ; // y index
86
88
int k = blockIdx .z ; // batch index
@@ -94,8 +96,8 @@ __global__ void ConvShiftGradX(const T *dout, const T *y, T *dx, int x_width,
94
96
95
97
// Compute y gradient - initial naive implementation with atomic add.
96
98
template <typename T>
97
- __global__ void ConvShiftDy (const T *x, const T *dout, T *dy , int x_width ,
98
- int y_width , int y_half_width, int batch_size ) {
99
+ __global__ void ConvShiftDy (const T *x, const T *dout, int x_width , int y_width ,
100
+ int y_half_width , int batch_size, T *dy ) {
99
101
int i = blockIdx .x * blockDim .x + threadIdx .x ; // x index
100
102
int j = blockIdx .y ; // y index
101
103
int k = blockIdx .z ; // batch index
@@ -133,7 +135,7 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
133
135
auto stream = context.cuda_device_context ().stream ();
134
136
135
137
ConvShiftForward<T><<<grid_dim, x_per_block, mem_per_block, stream>>> (
136
- x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
138
+ x_data, y_data, x_width, y_width, y_half_width, batch_size, out_data );
137
139
}
138
140
};
139
141
@@ -157,25 +159,26 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
157
159
int y_width = Y->dims ()[1 ];
158
160
int y_half_width = (y_width - 1 ) / 2 ;
159
161
160
- auto stream = context.cuda_device_context ().stream ();
162
+ auto &device_ctx = context.cuda_device_context ();
163
+ math::SetConstant<platform::GPUPlace, T> zero;
161
164
162
165
const int x_per_block = 256 ;
163
166
int num_x_blocks = DivUp (x_width, x_per_block);
164
167
dim3 grid_dim (num_x_blocks, y_width, batch_size);
165
168
166
169
if (dX) {
167
170
T *dx_data = dX->mutable_data <T>(context.GetPlace ());
168
- cudaMemsetAsync (dx_data, 0 , dX-> numel () * sizeof (T), stream );
169
- ConvShiftGradX<T><<<grid_dim, x_per_block, 0 , stream>>> (
170
- dout_data, y_data, dx_data, x_width, y_width, y_half_width,
171
- batch_size );
171
+ zero (device_ctx, dX, static_cast <T>( 0.0 ) );
172
+ ConvShiftGradX<T><<<grid_dim, x_per_block, 0 , device_ctx. stream() >>> (
173
+ dout_data, y_data, x_width, y_width, y_half_width, batch_size ,
174
+ dx_data );
172
175
}
173
176
if (dY) {
174
177
T *dy_data = dY->mutable_data <T>(context.GetPlace ());
175
- cudaMemsetAsync (dy_data, 0 , dY-> numel () * sizeof (T), stream );
176
- ConvShiftDy<T><<<grid_dim, x_per_block, 0 , stream>>> (
177
- x_data, dout_data, dy_data, x_width, y_width, y_half_width,
178
- batch_size );
178
+ zero (device_ctx, dY, static_cast <T>( 0.0 ) );
179
+ ConvShiftDy<T><<<grid_dim, x_per_block, 0 , device_ctx. stream() >>> (
180
+ x_data, dout_data, x_width, y_width, y_half_width, batch_size ,
181
+ dy_data );
179
182
}
180
183
}
181
184
};
0 commit comments