@@ -23,7 +23,8 @@ __global__ void KeNearestNeighborInterpFw(
23
23
const T* in, const size_t in_img_h, const size_t in_img_w,
24
24
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
25
25
const size_t out_img_w, const size_t output_h, const size_t output_w,
26
- const size_t num_channels, const float ratio_h, const float ratio_w) {
26
+ const size_t num_channels, const float ratio_h, const float ratio_w,
27
+ const bool align_corners) {
27
28
int nthreads = output_h * output_w;
28
29
int tid = blockIdx .x * blockDim .x + threadIdx .x ;
29
30
int stride = blockDim .x * gridDim .x ;
@@ -35,10 +36,14 @@ __global__ void KeNearestNeighborInterpFw(
35
36
int channel_id = out_id_w / out_img_size;
36
37
37
38
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
38
- int in_img_idy = static_cast <int >(ratio_h * out_img_idy + 0.5 );
39
+ int in_img_idy = (align_corners)
40
+ ? static_cast <int >(ratio_h * out_img_idy + 0.5 )
41
+ : static_cast <int >(ratio_h * out_img_idy);
39
42
40
43
int out_img_idx = tid % out_img_w;
41
- int in_img_idx = static_cast <int >(ratio_w * out_img_idx + 0.5 );
44
+ int in_img_idx = (align_corners)
45
+ ? static_cast <int >(ratio_w * out_img_idx + 0.5 )
46
+ : static_cast <int >(ratio_w * out_img_idx);
42
47
43
48
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
44
49
in_img_idy * in_img_w + in_img_idx];
@@ -50,7 +55,8 @@ __global__ void KeNearestNeighborInterpBw(
50
55
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
51
56
const size_t input_w, const T* out, const size_t out_img_h,
52
57
const size_t out_img_w, const size_t output_h, const size_t output_w,
53
- const size_t num_channels, const float ratio_h, const float ratio_w) {
58
+ const size_t num_channels, const float ratio_h, const float ratio_w,
59
+ const bool align_corners) {
54
60
int nthreads = output_h * output_w;
55
61
int tid = blockIdx .x * blockDim .x + threadIdx .x ;
56
62
int stride = blockDim .x * gridDim .x ;
@@ -62,10 +68,14 @@ __global__ void KeNearestNeighborInterpBw(
62
68
int channel_id = out_id_w / out_img_size;
63
69
64
70
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
65
- int in_img_idy = static_cast <int >(ratio_h * out_img_idy + 0.5 );
71
+ int in_img_idy = (align_corners)
72
+ ? static_cast <int >(ratio_h * out_img_idy + 0.5 )
73
+ : static_cast <int >(ratio_h * out_img_idy);
66
74
67
75
int out_img_idx = tid % out_img_w;
68
- int in_img_idx = static_cast <int >(ratio_w * out_img_idx + 0.5 );
76
+ int in_img_idx = (align_corners)
77
+ ? static_cast <int >(ratio_w * out_img_idx + 0.5 )
78
+ : static_cast <int >(ratio_w * out_img_idx);
69
79
70
80
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
71
81
in_img_idy * in_img_w + in_img_idx];
@@ -79,10 +89,12 @@ __global__ void KeBilinearInterpFw(
79
89
const T* in, const size_t in_img_h, const size_t in_img_w,
80
90
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
81
91
const size_t out_img_w, const size_t output_h, const size_t output_w,
82
- const size_t num_channels, const float ratio_h, const float ratio_w) {
92
+ const size_t num_channels, const float ratio_h, const float ratio_w,
93
+ const bool align_corners, const int align_mode) {
83
94
int nthreads = output_h * output_w;
84
95
int tid = blockIdx .x * blockDim .x + threadIdx .x ;
85
96
int stride = blockDim .x * gridDim .x ;
97
+ bool align_flag = (align_mode == 0 && !align_corners);
86
98
for (; tid < nthreads; tid += stride) {
87
99
int out_id_h = tid / output_w;
88
100
int out_id_w = tid % output_w;
@@ -91,15 +103,23 @@ __global__ void KeBilinearInterpFw(
91
103
int channel_id = out_id_w / out_img_size;
92
104
93
105
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
94
- int in_img_idy = ratio_h * out_img_idy;
106
+ int in_img_idy = align_flag
107
+ ? static_cast <int >(ratio_h * (out_img_idy + 0.5 ) - 0.5 )
108
+ : static_cast <int >(ratio_h * out_img_idy);
109
+ in_img_idy = (in_img_idy > 0 ) ? in_img_idy : 0 ;
95
110
int h_id = (in_img_idy < in_img_h - 1 ) ? 1 : 0 ;
96
- T h1lambda = ratio_h * out_img_idy - in_img_idy;
111
+ T h1lambda = align_flag ? ratio_h * (out_img_idy + 0.5 ) - 0.5 - in_img_idy
112
+ : ratio_h * out_img_idy - in_img_idy;
97
113
T h2lambda = 1 .f - h1lambda;
98
114
99
115
int out_img_idx = tid % out_img_w;
100
- int in_img_idx = ratio_w * out_img_idx;
116
+ int in_img_idx = align_flag
117
+ ? static_cast <int >(ratio_w * (out_img_idx + 0.5 ) - 0.5 )
118
+ : static_cast <int >(ratio_w * out_img_idx);
119
+ in_img_idx = (in_img_idx > 0 ) ? in_img_idx : 0 ;
101
120
int w_id = (in_img_idx < in_img_w - 1 ) ? 1 : 0 ;
102
- T w1lambda = ratio_w * out_img_idx - in_img_idx;
121
+ T w1lambda = align_flag ? ratio_w * (out_img_idx + 0.5 ) - 0.5 - in_img_idx
122
+ : ratio_w * out_img_idx - in_img_idx;
103
123
T w2lambda = 1 .f - w1lambda;
104
124
105
125
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
@@ -118,10 +138,12 @@ __global__ void KeBilinearInterpBw(
118
138
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
119
139
const size_t input_w, const T* out, const size_t out_img_h,
120
140
const size_t out_img_w, const size_t output_h, const size_t output_w,
121
- const size_t num_channels, const T ratio_h, const T ratio_w) {
141
+ const size_t num_channels, const T ratio_h, const T ratio_w,
142
+ const bool align_corners, const int align_mode) {
122
143
int nthreads = output_h * output_w;
123
144
int tid = blockIdx .x * blockDim .x + threadIdx .x ;
124
145
int stride = blockDim .x * gridDim .x ;
146
+ bool align_flag = (align_mode == 0 && !align_corners);
125
147
for (; tid < nthreads; tid += stride) {
126
148
int out_id_h = tid / output_w;
127
149
int out_id_w = tid % output_w;
@@ -130,15 +152,22 @@ __global__ void KeBilinearInterpBw(
130
152
int channel_id = out_id_w / out_img_size;
131
153
132
154
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
133
- int in_img_idy = ratio_h * out_img_idy;
155
+ int in_img_idy = align_flag ? ratio_h * (out_img_idy + 0.5 ) - 0.5
156
+ : ratio_h * out_img_idy;
157
+ in_img_idy = (in_img_idy > 0 ) ? in_img_idy : 0 ;
134
158
int h_id = (in_img_idy < in_img_h - 1 ) ? 1 : 0 ;
135
- T h1lambda = ratio_h * out_img_idy - in_img_idy;
159
+ T h1lambda = align_flag ? ratio_h * (out_img_idy + 0.5 ) - 0.5 - in_img_idy
160
+ : ratio_h * out_img_idy - in_img_idy;
161
+
136
162
T h2lambda = 1 .f - h1lambda;
137
163
138
164
int out_img_idx = tid % out_img_w;
139
- int in_img_idx = ratio_w * out_img_idx;
165
+ int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5 ) - 0.5
166
+ : ratio_w * out_img_idx;
167
+ in_img_idx = (in_img_idx > 0 ) ? in_img_idx : 0 ;
140
168
int w_id = (in_img_idx < in_img_w - 1 ) ? 1 : 0 ;
141
- T w1lambda = ratio_w * out_img_idx - in_img_idx;
169
+ T w1lambda = align_flag ? ratio_w * (out_img_idx + 0.5 ) - 0.5 - in_img_idx
170
+ : ratio_w * out_img_idx - in_img_idx;
142
171
T w2lambda = 1 .f - w1lambda;
143
172
144
173
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
@@ -175,6 +204,9 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
175
204
out_w = size_data[1 ];
176
205
}
177
206
207
+ bool align_corners = ctx.Attr <bool >(" align_corners" );
208
+ int align_mode = ctx.Attr <int >(" align_mode" );
209
+
178
210
int n = input->dims ()[0 ];
179
211
int c = input->dims ()[1 ];
180
212
int in_h = input->dims ()[2 ];
@@ -188,10 +220,16 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
188
220
int in_chw = c * in_hw;
189
221
int out_chw = c * out_hw;
190
222
191
- float ratio_h =
192
- (out_h > 1 ) ? static_cast <float >(in_h - 1 ) / (out_h - 1 ) : 0 .f ;
193
- float ratio_w =
194
- (out_w > 1 ) ? static_cast <float >(in_w - 1 ) / (out_w - 1 ) : 0 .f ;
223
+ float ratio_h = 0 .f ;
224
+ float ratio_w = 0 .f ;
225
+ if (out_h > 1 ) {
226
+ ratio_h = (align_corners) ? static_cast <float >(in_h - 1 ) / (out_h - 1 )
227
+ : static_cast <float >(in_h) / out_h;
228
+ }
229
+ if (out_w > 1 ) {
230
+ ratio_w = (align_corners) ? static_cast <float >(in_w - 1 ) / (out_w - 1 )
231
+ : static_cast <float >(in_w) / out_w;
232
+ }
195
233
196
234
if (in_h == out_h && in_w == out_w) {
197
235
framework::TensorCopy (*input, ctx.GetPlace (), output);
@@ -206,12 +244,12 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
206
244
KeNearestNeighborInterpFw<
207
245
T><<<grid_dim, 512 , 0 , ctx.cuda_device_context().stream()>>> (
208
246
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
209
- out_chw, c, ratio_h, ratio_w);
247
+ out_chw, c, ratio_h, ratio_w, align_corners );
210
248
} else if (" bilinear" == interp_method) {
211
249
KeBilinearInterpFw<
212
250
T><<<grid_dim, 512 , 0 , ctx.cuda_device_context().stream()>>> (
213
251
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
214
- out_chw, c, ratio_h, ratio_w);
252
+ out_chw, c, ratio_h, ratio_w, align_corners, align_mode );
215
253
}
216
254
}
217
255
};
@@ -234,6 +272,10 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
234
272
int out_h = ctx.Attr <int >(" out_h" );
235
273
int out_w = ctx.Attr <int >(" out_w" );
236
274
auto out_size = ctx.Input <Tensor>(" OutSize" );
275
+
276
+ bool align_corners = ctx.Attr <bool >(" align_corners" );
277
+ int align_mode = ctx.Attr <int >(" align_mode" );
278
+
237
279
if (out_size != nullptr ) {
238
280
Tensor sizes;
239
281
framework::TensorCopy (*out_size, platform::CPUPlace (), &sizes);
@@ -252,10 +294,16 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
252
294
int in_chw = c * in_hw;
253
295
int out_chw = c * out_hw;
254
296
255
- float ratio_h =
256
- (out_h > 1 ) ? static_cast <float >(in_h - 1 ) / (out_h - 1 ) : 0 .f ;
257
- float ratio_w =
258
- (out_w > 1 ) ? static_cast <float >(in_w - 1 ) / (out_w - 1 ) : 0 .f ;
297
+ float ratio_h = 0 .f ;
298
+ float ratio_w = 0 .f ;
299
+ if (out_h > 1 ) {
300
+ ratio_h = (align_corners) ? static_cast <float >(in_h - 1 ) / (out_h - 1 )
301
+ : static_cast <float >(in_h) / out_h;
302
+ }
303
+ if (out_w > 1 ) {
304
+ ratio_w = (align_corners) ? static_cast <float >(in_w - 1 ) / (out_w - 1 )
305
+ : static_cast <float >(in_w) / out_w;
306
+ }
259
307
260
308
if (in_h == out_h && in_w == out_w) {
261
309
framework::TensorCopy (*output_grad, ctx.GetPlace (), input_grad);
@@ -270,12 +318,12 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
270
318
KeNearestNeighborInterpBw<
271
319
T><<<grid_dim, 512 , 0 , ctx.cuda_device_context().stream()>>> (
272
320
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
273
- out_w, n, out_chw, c, ratio_h, ratio_w);
321
+ out_w, n, out_chw, c, ratio_h, ratio_w, align_corners );
274
322
} else if (" bilinear" == interp_method) {
275
323
KeBilinearInterpBw<
276
324
T><<<grid_dim, 512 , 0 , ctx.cuda_device_context().stream()>>> (
277
325
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
278
- out_w, n, out_chw, c, ratio_h, ratio_w);
326
+ out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode );
279
327
}
280
328
}
281
329
};
0 commit comments