@@ -22,7 +22,7 @@ using framework::Tensor;
 
 namespace {
 
-inline int div_up(int x, int y) { return (x + y - 1) / y; }
+inline int DivUp(int x, int y) { return (x + y - 1) / y; }
 
 // Some notes on the design:
 //
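[Note, not part of the patch] DivUp is plain ceiling division; it is used further down to size the launch grid so the last, possibly partial, tile of x still gets a block. A minimal host-side check of the helper:

#include <cassert>

inline int DivUp(int x, int y) { return (x + y - 1) / y; }

int main() {
  assert(DivUp(10, 4) == 3);   // 10 elements in blocks of 4 need 3 blocks
  assert(DivUp(8, 4) == 2);    // exact multiples need no extra block
  assert(DivUp(1, 256) == 1);  // any non-empty input needs at least one block
  return 0;
}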
@@ -33,9 +33,9 @@ inline int div_up(int x, int y) { return (x + y - 1) / y; }
 // y is fairly small. For large y, it would probably be more efficient
 // to also tile across y.
 template <typename T>
-__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width,
-                                   int y_width, int y_half_width,
-                                   int batch_size) {
+__global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
+                                 int y_width, int y_half_width,
+                                 int batch_size) {
   extern __shared__ T mem[];
 
   int tx = threadIdx.x;
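[Note, not part of the patch] The kernel body is outside this hunk. As a point of reference for the design notes above, here is a CPU sketch of the circular convolution conv_shift is understood to compute; the row-major layout and the name ConvShiftReference are my assumptions, not code from this file:

// Hypothetical CPU reference of
//   out[k, i] = sum_j x[k, (i + j - y_half_width) mod x_width] * y[k, j]
#include <vector>

template <typename T>
void ConvShiftReference(const std::vector<T> &x, const std::vector<T> &y,
                        std::vector<T> &out, int x_width, int y_width,
                        int batch_size) {
  int y_half_width = (y_width - 1) / 2;
  out.assign(batch_size * x_width, T(0));
  for (int k = 0; k < batch_size; ++k) {    // batch index
    for (int i = 0; i < x_width; ++i) {     // output position
      T sum = 0;
      for (int j = 0; j < y_width; ++j) {   // filter tap
        // Wrap the x index circularly; "+ x_width" keeps it non-negative.
        int idx = ((i + j - y_half_width) % x_width + x_width) % x_width;
        sum += x[k * x_width + idx] * y[k * y_width + j];
      }
      out[k * x_width + i] = sum;
    }
  }
}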
@@ -79,8 +79,8 @@ __global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width,
 
 // Compute x gradient - initial naive implementation with atomic add.
 template <typename T>
-__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width,
-                              int y_width, int y_half_width, int batch_size) {
+__global__ void ConvShiftGradX(const T *dout, const T *y, T *dx, int x_width,
+                               int y_width, int y_half_width, int batch_size) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;  // x index
   int j = blockIdx.y;                             // y index
   int k = blockIdx.z;                             // batch index
@@ -94,8 +94,8 @@ __global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width,
 
 // Compute y gradient - initial naive implementation with atomic add.
 template <typename T>
-__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width,
-                              int y_width, int y_half_width, int batch_size) {
+__global__ void ConvShiftDy(const T *x, const T *dout, T *dy, int x_width,
+                            int y_width, int y_half_width, int batch_size) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;  // x index
   int j = blockIdx.y;                             // y index
   int k = blockIdx.z;                             // batch index
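[Note, not part of the patch] The gradient bodies are also outside these hunks. Purely for orientation, a hypothetical sketch of the kind of naive atomic-add kernel the comments describe, matching the one-thread-per-(i, j, k) launch grid used below; the indexing is assumed, not taken from the file:

// Hypothetical sketch only (not the file's code): one thread per
// (x position i, filter tap j, batch k) scatters its contribution into dx.
__global__ void ConvShiftGradXSketch(const float *dout, const float *y,
                                     float *dx, int x_width, int y_width,
                                     int y_half_width) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // x index
  int j = blockIdx.y;                             // filter tap index
  int k = blockIdx.z;                             // batch index
  if (i < x_width) {
    // Assumed circular indexing, mirroring the forward-pass sketch above.
    int idx = ((i + j - y_half_width) % x_width + x_width) % x_width;
    // Different (i, j) pairs can land on the same idx, hence the atomic.
    atomicAdd(&dx[k * x_width + idx],
              dout[k * x_width + i] * y[k * y_width + j]);
  }
}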
@@ -125,14 +125,14 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
     int y_half_width = (y_width - 1) / 2;
 
     const int x_per_block = 256;
-    int num_x_blocks = div_up(x_width, x_per_block);
+    int num_x_blocks = DivUp(x_width, x_per_block);
     int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T);
 
     dim3 grid_dim(num_x_blocks, batch_size);
 
     auto stream = context.cuda_device_context().stream();
 
-    conv_shift_forward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
+    ConvShiftForward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
         x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
   }
 };
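[Note, not part of the patch] For readers less used to the launch syntax: mem_per_block is passed as the third <<<...>>> argument and backs the extern __shared__ T mem[]; declaration in ConvShiftForward, while blockIdx.y walks the batch dimension. A stripped-down, self-contained illustration of that launch pattern with a hypothetical kernel:

#include <cstdio>

// Each block stages its tile of x into dynamically sized shared memory,
// then writes it back doubled; blockIdx.y selects the batch row.
__global__ void TileCopyDemo(const float *x, float *out, int x_width) {
  extern __shared__ float tile[];
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int k = blockIdx.y;  // batch index
  if (i < x_width) tile[threadIdx.x] = x[k * x_width + i];
  __syncthreads();
  if (i < x_width) out[k * x_width + i] = 2.0f * tile[threadIdx.x];
}

int main() {
  const int x_width = 1000, batch_size = 4, x_per_block = 256;
  float *x, *out;
  cudaMalloc(&x, batch_size * x_width * sizeof(float));
  cudaMalloc(&out, batch_size * x_width * sizeof(float));
  cudaMemset(x, 0, batch_size * x_width * sizeof(float));

  int num_x_blocks = (x_width + x_per_block - 1) / x_per_block;  // DivUp
  dim3 grid_dim(num_x_blocks, batch_size);
  size_t mem_per_block = x_per_block * sizeof(float);  // dynamic shared size
  TileCopyDemo<<<grid_dim, x_per_block, mem_per_block>>>(x, out, x_width);
  cudaDeviceSynchronize();
  printf("launched %d x %d blocks\n", num_x_blocks, batch_size);
  cudaFree(x);
  cudaFree(out);
  return 0;
}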
@@ -160,20 +160,20 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
     auto stream = context.cuda_device_context().stream();
 
     const int x_per_block = 256;
-    int num_x_blocks = div_up(x_width, x_per_block);
+    int num_x_blocks = DivUp(x_width, x_per_block);
     dim3 grid_dim(num_x_blocks, y_width, batch_size);
 
     if (dX) {
       T *dx_data = dX->mutable_data<T>(context.GetPlace());
       cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream);
-      conv_shift_dx<T><<<grid_dim, x_per_block, 0, stream>>>(
+      ConvShiftGradX<T><<<grid_dim, x_per_block, 0, stream>>>(
          dout_data, y_data, dx_data, x_width, y_width, y_half_width,
          batch_size);
     }
     if (dY) {
       T *dy_data = dY->mutable_data<T>(context.GetPlace());
       cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream);
-      conv_shift_dy<T><<<grid_dim, x_per_block, 0, stream>>>(
+      ConvShiftDy<T><<<grid_dim, x_per_block, 0, stream>>>(
          x_data, dout_data, dy_data, x_width, y_width, y_half_width,
          batch_size);
     }
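[Note, not part of the patch] In both gradient branches the cudaMemsetAsync and the kernel launch are issued on the same stream, so the zero-fill is guaranteed to complete before the atomic accumulation starts and no extra synchronization is needed. A self-contained demo of that memset-then-accumulate ordering:

#include <cstdio>

// Each thread adds 1.0f into a single counter; correctness relies on the
// buffer having been zeroed by the preceding memset on the same stream.
__global__ void AccumulateOnes(float *counter) { atomicAdd(counter, 1.0f); }

int main() {
  float *counter;
  cudaMalloc(&counter, sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Same-stream ordering: the async memset finishes before the kernel runs.
  cudaMemsetAsync(counter, 0, sizeof(float), stream);
  AccumulateOnes<<<4, 256, 0, stream>>>(counter);

  float result = 0.0f;
  cudaMemcpyAsync(&result, counter, sizeof(float), cudaMemcpyDeviceToHost,
                  stream);
  cudaStreamSynchronize(stream);
  printf("accumulated %.0f (expected 1024)\n", result);

  cudaStreamDestroy(stream);
  cudaFree(counter);
  return 0;
}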