20
20
21
21
namespace phi {
22
22
23
- template <typename T>
23
+ template <typename T, typename IndexT >
24
24
inline void ModulatedDeformableCol2imCPUKernel (
25
- const int num_kernels,
25
+ const IndexT num_kernels,
26
26
const T* data_col,
27
27
const T* data_offset,
28
28
const T* data_mask,
29
- const int channels,
30
- const int height,
31
- const int width,
32
- const int kernel_h,
33
- const int kernel_w,
34
- const int pad_h,
35
- const int pad_w,
36
- const int stride_h,
37
- const int stride_w,
38
- const int dilation_h,
39
- const int dilation_w,
40
- const int channel_per_deformable_group,
41
- const int batch_size,
42
- const int deformable_group,
43
- const int height_col,
44
- const int width_col,
29
+ const IndexT channels,
30
+ const IndexT height,
31
+ const IndexT width,
32
+ const IndexT kernel_h,
33
+ const IndexT kernel_w,
34
+ const IndexT pad_h,
35
+ const IndexT pad_w,
36
+ const IndexT stride_h,
37
+ const IndexT stride_w,
38
+ const IndexT dilation_h,
39
+ const IndexT dilation_w,
40
+ const IndexT channel_per_deformable_group,
41
+ const IndexT batch_size,
42
+ const IndexT deformable_group,
43
+ const IndexT height_col,
44
+ const IndexT width_col,
45
45
T* grad_im) {
46
- for (int thread = 0 ; thread < num_kernels; thread++) {
47
- const int j = (thread / width_col / height_col / batch_size) % kernel_w;
48
- const int i =
46
+ for (IndexT thread = 0 ; thread < num_kernels; thread++) {
47
+ const IndexT j = (thread / width_col / height_col / batch_size) % kernel_w;
48
+ const IndexT i =
49
49
(thread / width_col / height_col / batch_size / kernel_w) % kernel_h;
50
- const int c =
50
+ const IndexT c =
51
51
thread / width_col / height_col / batch_size / kernel_w / kernel_h;
52
52
53
- const int deformable_group_index = c / channel_per_deformable_group;
53
+ const IndexT deformable_group_index = c / channel_per_deformable_group;
54
54
55
- int w_out = thread % width_col;
56
- int h_out = (thread / width_col) % height_col;
57
- int b = (thread / width_col / height_col) % batch_size;
58
- int w_in = w_out * stride_w - pad_w;
59
- int h_in = h_out * stride_h - pad_h;
55
+ IndexT w_out = thread % width_col;
56
+ IndexT h_out = (thread / width_col) % height_col;
57
+ IndexT b = (thread / width_col / height_col) % batch_size;
58
+ IndexT w_in = w_out * stride_w - pad_w;
59
+ IndexT h_in = h_out * stride_h - pad_h;
60
60
61
61
const T* data_offset_ptr =
62
62
data_offset + (b * deformable_group + deformable_group_index) * 2 *
63
63
kernel_h * kernel_w * height_col * width_col;
64
- const int data_offset_h_ptr =
64
+ const IndexT data_offset_h_ptr =
65
65
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
66
- const int data_offset_w_ptr =
66
+ const IndexT data_offset_w_ptr =
67
67
((2 * (i * kernel_w + j) + 1 ) * height_col + h_out) * width_col + w_out;
68
- const int data_mask_hw_ptr =
68
+ const IndexT data_mask_hw_ptr =
69
69
((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
70
70
const T offset_h = data_offset_ptr[data_offset_h_ptr];
71
71
const T offset_w = data_offset_ptr[data_offset_w_ptr];
@@ -80,14 +80,14 @@ inline void ModulatedDeformableCol2imCPUKernel(
80
80
const T mask = data_mask_ptr[data_mask_hw_ptr];
81
81
cur_top_grad *= mask;
82
82
}
83
- const int cur_h = static_cast <int >(cur_inv_h_data);
84
- const int cur_w = static_cast <int >(cur_inv_w_data);
85
- for (int dy = -2 ; dy <= 2 ; dy++) {
86
- for (int dx = -2 ; dx <= 2 ; dx++) {
83
+ const IndexT cur_h = static_cast <IndexT >(cur_inv_h_data);
84
+ const IndexT cur_w = static_cast <IndexT >(cur_inv_w_data);
85
+ for (IndexT dy = -2 ; dy <= 2 ; dy++) {
86
+ for (IndexT dx = -2 ; dx <= 2 ; dx++) {
87
87
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
88
88
cur_w + dx < width && abs (cur_inv_h_data - (cur_h + dy)) < 1 &&
89
89
abs (cur_inv_w_data - (cur_w + dx)) < 1 ) {
90
- int cur_bottom_grad_pos =
90
+ IndexT cur_bottom_grad_pos =
91
91
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
92
92
T weight = DmcnGetGradientWeight (cur_inv_h_data,
93
93
cur_inv_w_data,
@@ -104,7 +104,7 @@ inline void ModulatedDeformableCol2imCPUKernel(
104
104
}
105
105
}
106
106
107
- template <typename T, typename Context>
107
+ template <typename T, typename Context, typename IndexT >
108
108
void ModulatedDeformableCol2im (const Context& dev_ctx,
109
109
const T* data_col,
110
110
const T* data_offset,
@@ -117,70 +117,69 @@ void ModulatedDeformableCol2im(const Context& dev_ctx,
117
117
const std::vector<int >& dilation,
118
118
const int deformable_group,
119
119
T* grad_im) {
120
- int channel_per_deformable_group =
121
- static_cast <int >(im_shape[0 ] / deformable_group);
122
- int num_kernels = static_cast <int >(col_shape[0 ] * col_shape[1 ] *
123
- col_shape[2 ] * col_shape[3 ]);
120
+ int64_t channel_per_deformable_group = im_shape[0 ] / deformable_group;
121
+ int64_t num_kernels =
122
+ col_shape[0 ] * col_shape[1 ] * col_shape[2 ] * col_shape[3 ];
124
123
125
- ModulatedDeformableCol2imCPUKernel (num_kernels,
126
- data_col,
127
- data_offset,
128
- data_mask,
129
- im_shape[0 ],
130
- im_shape[1 ],
131
- im_shape[2 ],
132
- kernel_shape[2 ],
133
- kernel_shape[3 ],
134
- pad[0 ],
135
- pad[1 ],
136
- stride[0 ],
137
- stride[1 ],
138
- dilation[0 ],
139
- dilation[1 ],
140
- channel_per_deformable_group,
141
- col_shape[1 ],
142
- deformable_group,
143
- col_shape[2 ],
144
- col_shape[3 ],
145
- grad_im);
124
+ ModulatedDeformableCol2imCPUKernel<T, IndexT> (num_kernels,
125
+ data_col,
126
+ data_offset,
127
+ data_mask,
128
+ im_shape[0 ],
129
+ im_shape[1 ],
130
+ im_shape[2 ],
131
+ kernel_shape[2 ],
132
+ kernel_shape[3 ],
133
+ pad[0 ],
134
+ pad[1 ],
135
+ stride[0 ],
136
+ stride[1 ],
137
+ dilation[0 ],
138
+ dilation[1 ],
139
+ channel_per_deformable_group,
140
+ col_shape[1 ],
141
+ deformable_group,
142
+ col_shape[2 ],
143
+ col_shape[3 ],
144
+ grad_im);
146
145
}
147
146
148
- template <typename T>
147
+ template <typename T, typename IndexT >
149
148
void ModulatedDeformableCol2imCoordCPUKernel (
150
- const int num_kernels,
149
+ const IndexT num_kernels,
151
150
const T* data_col,
152
151
const T* data_im,
153
152
const T* data_offset,
154
153
const T* data_mask,
155
- const int channels,
156
- const int height,
157
- const int width,
158
- const int kernel_h,
159
- const int kernel_w,
160
- const int pad_h,
161
- const int pad_w,
162
- const int stride_h,
163
- const int stride_w,
164
- const int dilation_h,
165
- const int dilation_w,
166
- const int channel_per_deformable_group,
167
- const int batch_size,
168
- const int offset_channels,
169
- const int deformable_group,
170
- const int height_col,
171
- const int width_col,
154
+ const IndexT channels,
155
+ const IndexT height,
156
+ const IndexT width,
157
+ const IndexT kernel_h,
158
+ const IndexT kernel_w,
159
+ const IndexT pad_h,
160
+ const IndexT pad_w,
161
+ const IndexT stride_h,
162
+ const IndexT stride_w,
163
+ const IndexT dilation_h,
164
+ const IndexT dilation_w,
165
+ const IndexT channel_per_deformable_group,
166
+ const IndexT batch_size,
167
+ const IndexT offset_channels,
168
+ const IndexT deformable_group,
169
+ const IndexT height_col,
170
+ const IndexT width_col,
172
171
T* grad_offset,
173
172
T* grad_mask) {
174
- for (int i = 0 ; i < num_kernels; i++) {
173
+ for (IndexT i = 0 ; i < num_kernels; i++) {
175
174
T val = 0 , mval = 0 ;
176
- const int w = i % width_col;
177
- const int h = (i / width_col) % height_col;
178
- const int c = (i / width_col / height_col) % offset_channels;
179
- const int b = (i / width_col / height_col) / offset_channels;
175
+ const IndexT w = i % width_col;
176
+ const IndexT h = (i / width_col) % height_col;
177
+ const IndexT c = (i / width_col / height_col) % offset_channels;
178
+ const IndexT b = (i / width_col / height_col) / offset_channels;
180
179
181
- const int deformable_group_index = c / (2 * kernel_h * kernel_w);
182
- const int col_step = kernel_h * kernel_w;
183
- int cnt = 0 ;
180
+ const IndexT deformable_group_index = c / (2 * kernel_h * kernel_w);
181
+ const IndexT col_step = kernel_h * kernel_w;
182
+ IndexT cnt = 0 ;
184
183
const T* data_col_ptr = data_col + deformable_group_index *
185
184
channel_per_deformable_group *
186
185
batch_size * width_col * height_col;
@@ -197,24 +196,25 @@ void ModulatedDeformableCol2imCoordCPUKernel(
197
196
kernel_h * kernel_w * height_col * width_col
198
197
: nullptr ;
199
198
200
- const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
199
+ const IndexT offset_c =
200
+ c - deformable_group_index * 2 * kernel_h * kernel_w;
201
201
202
- for (int col_c = offset_c / 2 ; col_c < channel_per_deformable_group;
202
+ for (IndexT col_c = offset_c / 2 ; col_c < channel_per_deformable_group;
203
203
col_c += col_step) {
204
- const int col_pos =
204
+ const IndexT col_pos =
205
205
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
206
- const int bp_dir = offset_c % 2 ;
206
+ const IndexT bp_dir = offset_c % 2 ;
207
207
208
- int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
209
- int i =
208
+ IndexT j = (col_pos / width_col / height_col / batch_size) % kernel_w;
209
+ IndexT i =
210
210
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
211
- int w_out = col_pos % width_col;
212
- int h_out = (col_pos / width_col) % height_col;
213
- int w_in = w_out * stride_w - pad_w;
214
- int h_in = h_out * stride_h - pad_h;
215
- const int data_offset_h_ptr =
211
+ IndexT w_out = col_pos % width_col;
212
+ IndexT h_out = (col_pos / width_col) % height_col;
213
+ IndexT w_in = w_out * stride_w - pad_w;
214
+ IndexT h_in = h_out * stride_h - pad_h;
215
+ const IndexT data_offset_h_ptr =
216
216
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
217
- const int data_offset_w_ptr =
217
+ const IndexT data_offset_w_ptr =
218
218
(((2 * (i * kernel_w + j) + 1 ) * height_col + h_out) * width_col +
219
219
w_out);
220
220
const T offset_h = data_offset_ptr[data_offset_h_ptr];
@@ -241,7 +241,7 @@ void ModulatedDeformableCol2imCoordCPUKernel(
241
241
width,
242
242
bp_dir);
243
243
if (data_mask_ptr) {
244
- const int data_mask_hw_ptr =
244
+ const IndexT data_mask_hw_ptr =
245
245
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
246
246
const T mask = data_mask_ptr[data_mask_hw_ptr];
247
247
val += weight * data_col_ptr[col_pos] * mask;
@@ -262,7 +262,7 @@ void ModulatedDeformableCol2imCoordCPUKernel(
262
262
}
263
263
}
264
264
265
- template <typename T, typename Context>
265
+ template <typename T, typename Context, typename IndexT >
266
266
void ModulatedDeformableCol2imCoord (const Context& dev_ctx,
267
267
const T* data_col,
268
268
const T* data_im,
@@ -277,13 +277,11 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
277
277
const int deformable_groups,
278
278
T* grad_offset,
279
279
T* grad_mask) {
280
- int num_kernels =
281
- static_cast <int >(2 * kernel_shape[2 ] * kernel_shape[3 ] * col_shape[1 ] *
282
- col_shape[2 ] * col_shape[3 ] * deformable_groups);
283
- int channel_per_deformable_group =
284
- static_cast <int >(col_shape[0 ] / deformable_groups);
280
+ int64_t num_kernels = 2 * kernel_shape[2 ] * kernel_shape[3 ] * col_shape[1 ] *
281
+ col_shape[2 ] * col_shape[3 ] * deformable_groups;
282
+ int64_t channel_per_deformable_group = col_shape[0 ] / deformable_groups;
285
283
286
- ModulatedDeformableCol2imCoordCPUKernel (
284
+ ModulatedDeformableCol2imCoordCPUKernel<T, IndexT> (
287
285
num_kernels,
288
286
data_col,
289
287
data_im,
@@ -310,15 +308,15 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
310
308
grad_mask);
311
309
}
312
310
313
- template <typename T, typename Context>
311
+ template <typename T, typename Context, typename IndexT >
314
312
void FilterGradAddup (const Context& dev_ctx,
315
- const int nthreads,
316
- const int n,
317
- const int height,
318
- const int width,
313
+ const int64_t nthreads,
314
+ const int64_t n,
315
+ const int64_t height,
316
+ const int64_t width,
319
317
const T* dweight_3d,
320
318
T* filter_grad) {
321
- for (int i = 0 ; i < nthreads; i++) {
319
+ for (IndexT i = 0 ; i < nthreads; i++) {
322
320
filter_grad[i] = filter_grad[i] + dweight_3d[i];
323
321
}
324
322
}
0 commit comments