13
13
limitations under the License. */
14
14
15
15
#include " paddle/operators/conv_shift_op.h"
16
+ #include " paddle/operators/math/math_function.h"
16
17
#include " paddle/platform/cuda_helper.h"
17
18
18
19
namespace paddle {
@@ -22,7 +23,7 @@ using framework::Tensor;
22
23
23
24
namespace {
24
25
25
// Ceiling integer division: smallest integer q such that q * denominator >= numerator.
// Intended for non-negative numerator and positive denominator (grid-size math).
inline int DivUp(int numerator, int denominator) {
  return (numerator + denominator - 1) / denominator;
}
26
27
27
28
// Some notes on the design:
28
29
//
@@ -33,9 +34,9 @@ inline int div_up(int x, int y) { return (x + y - 1) / y; }
33
34
// y is fairly small. For large y, it would probably be more efficient
34
35
// to also tile across y.
35
36
template <typename T>
36
- __global__ void conv_shift_forward (const T *x, const T *y, T *out , int x_width,
37
- int y_width, int y_half_width,
38
- int batch_size ) {
37
+ __global__ void ConvShiftForward (const T *x, const T *y, int x_width,
38
+ int y_width, int y_half_width, int batch_size ,
39
+ T *out ) {
39
40
extern __shared__ T mem[];
40
41
41
42
int tx = threadIdx .x ;
@@ -62,25 +63,26 @@ __global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width,
62
63
if (tx < num_x) {
63
64
int load_i = (i - y_half_width + x_width) % x_width;
64
65
sx[tx] = x[k * x_width + load_i];
65
- } else {
66
- return ;
67
66
}
68
67
__syncthreads ();
69
68
70
- // Compute dot product of sx[tx:tx + y_width] and sy.
71
- T sum = 0 ;
72
- for (int j = 0 ; j < y_width; ++j) {
73
- sum += sx[tx + j] * sy[j];
74
- }
69
+ if (tx < num_x) {
70
+ // Compute dot product of sx[tx:tx + y_width] and sy.
71
+ T sum = 0 ;
72
+ for (int j = 0 ; j < y_width; ++j) {
73
+ sum += sx[tx + j] * sy[j];
74
+ }
75
75
76
- // Save to out[k, i].
77
- out[k * x_width + i] = sum;
76
+ // Save to out[k, i].
77
+ out[k * x_width + i] = sum;
78
+ }
78
79
}
79
80
80
81
// Compute x gradient - initial naive implementation with atomic add.
81
82
template <typename T>
82
- __global__ void conv_shift_dx (const T *dout, const T *y, T *dx, int x_width,
83
- int y_width, int y_half_width, int batch_size) {
83
+ __global__ void ConvShiftGradX (const T *dout, const T *y, int x_width,
84
+ int y_width, int y_half_width, int batch_size,
85
+ T *dx) {
84
86
int i = blockIdx .x * blockDim .x + threadIdx .x ; // x index
85
87
int j = blockIdx .y ; // y index
86
88
int k = blockIdx .z ; // batch index
@@ -94,8 +96,8 @@ __global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width,
94
96
95
97
// Compute y gradient - initial naive implementation with atomic add.
96
98
template <typename T>
97
- __global__ void conv_shift_dy (const T *x, const T *dout, T *dy , int x_width ,
98
- int y_width , int y_half_width, int batch_size ) {
99
+ __global__ void ConvShiftDy (const T *x, const T *dout, int x_width , int y_width ,
100
+ int y_half_width , int batch_size, T *dy ) {
99
101
int i = blockIdx .x * blockDim .x + threadIdx .x ; // x index
100
102
int j = blockIdx .y ; // y index
101
103
int k = blockIdx .z ; // batch index
@@ -125,15 +127,15 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
125
127
int y_half_width = (y_width - 1 ) / 2 ;
126
128
127
129
const int x_per_block = 256 ;
128
- int num_x_blocks = div_up (x_width, x_per_block);
130
+ int num_x_blocks = DivUp (x_width, x_per_block);
129
131
int mem_per_block = (x_per_block + 2 * y_width) * sizeof (T);
130
132
131
133
dim3 grid_dim (num_x_blocks, batch_size);
132
134
133
135
auto stream = context.cuda_device_context ().stream ();
134
136
135
- conv_shift_forward <T><<<grid_dim, x_per_block, mem_per_block, stream>>> (
136
- x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
137
+ ConvShiftForward <T><<<grid_dim, x_per_block, mem_per_block, stream>>> (
138
+ x_data, y_data, x_width, y_width, y_half_width, batch_size, out_data );
137
139
}
138
140
};
139
141
@@ -157,25 +159,26 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
157
159
int y_width = Y->dims ()[1 ];
158
160
int y_half_width = (y_width - 1 ) / 2 ;
159
161
160
- auto stream = context.cuda_device_context ().stream ();
162
+ auto &device_ctx = context.cuda_device_context ();
163
+ math::SetConstant<platform::GPUPlace, T> zero;
161
164
162
165
const int x_per_block = 256 ;
163
- int num_x_blocks = div_up (x_width, x_per_block);
166
+ int num_x_blocks = DivUp (x_width, x_per_block);
164
167
dim3 grid_dim (num_x_blocks, y_width, batch_size);
165
168
166
169
if (dX) {
167
170
T *dx_data = dX->mutable_data <T>(context.GetPlace ());
168
- cudaMemsetAsync (dx_data, 0 , dX-> numel () * sizeof (T), stream );
169
- conv_shift_dx <T><<<grid_dim, x_per_block, 0 , stream>>> (
170
- dout_data, y_data, dx_data, x_width, y_width, y_half_width,
171
- batch_size );
171
+ zero (device_ctx, dX, static_cast <T>( 0.0 ) );
172
+ ConvShiftGradX <T><<<grid_dim, x_per_block, 0 , device_ctx. stream() >>> (
173
+ dout_data, y_data, x_width, y_width, y_half_width, batch_size ,
174
+ dx_data );
172
175
}
173
176
if (dY) {
174
177
T *dy_data = dY->mutable_data <T>(context.GetPlace ());
175
- cudaMemsetAsync (dy_data, 0 , dY-> numel () * sizeof (T), stream );
176
- conv_shift_dy <T><<<grid_dim, x_per_block, 0 , stream>>> (
177
- x_data, dout_data, dy_data, x_width, y_width, y_half_width,
178
- batch_size );
178
+ zero (device_ctx, dY, static_cast <T>( 0.0 ) );
179
+ ConvShiftDy <T><<<grid_dim, x_per_block, 0 , device_ctx. stream() >>> (
180
+ x_data, dout_data, x_width, y_width, y_half_width, batch_size ,
181
+ dy_data );
179
182
}
180
183
}
181
184
};
0 commit comments