Skip to content

Commit 460b539

Browse files
authored
Fix paddle.vision.ops.deform_conv2d API for big Tensors (#74324)
* fix deform_conv2d
* fix
* fix
1 parent 7159dbf commit 460b539

8 files changed

+552 -402 lines changed

paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc

Lines changed: 113 additions & 115 deletions
Original file line number | Diff line number | Diff line change
@@ -20,52 +20,52 @@
2020

2121
namespace phi {
2222

23-
template <typename T>
23+
template <typename T, typename IndexT>
2424
inline void ModulatedDeformableCol2imCPUKernel(
25-
const int num_kernels,
25+
const IndexT num_kernels,
2626
const T* data_col,
2727
const T* data_offset,
2828
const T* data_mask,
29-
const int channels,
30-
const int height,
31-
const int width,
32-
const int kernel_h,
33-
const int kernel_w,
34-
const int pad_h,
35-
const int pad_w,
36-
const int stride_h,
37-
const int stride_w,
38-
const int dilation_h,
39-
const int dilation_w,
40-
const int channel_per_deformable_group,
41-
const int batch_size,
42-
const int deformable_group,
43-
const int height_col,
44-
const int width_col,
29+
const IndexT channels,
30+
const IndexT height,
31+
const IndexT width,
32+
const IndexT kernel_h,
33+
const IndexT kernel_w,
34+
const IndexT pad_h,
35+
const IndexT pad_w,
36+
const IndexT stride_h,
37+
const IndexT stride_w,
38+
const IndexT dilation_h,
39+
const IndexT dilation_w,
40+
const IndexT channel_per_deformable_group,
41+
const IndexT batch_size,
42+
const IndexT deformable_group,
43+
const IndexT height_col,
44+
const IndexT width_col,
4545
T* grad_im) {
46-
for (int thread = 0; thread < num_kernels; thread++) {
47-
const int j = (thread / width_col / height_col / batch_size) % kernel_w;
48-
const int i =
46+
for (IndexT thread = 0; thread < num_kernels; thread++) {
47+
const IndexT j = (thread / width_col / height_col / batch_size) % kernel_w;
48+
const IndexT i =
4949
(thread / width_col / height_col / batch_size / kernel_w) % kernel_h;
50-
const int c =
50+
const IndexT c =
5151
thread / width_col / height_col / batch_size / kernel_w / kernel_h;
5252

53-
const int deformable_group_index = c / channel_per_deformable_group;
53+
const IndexT deformable_group_index = c / channel_per_deformable_group;
5454

55-
int w_out = thread % width_col;
56-
int h_out = (thread / width_col) % height_col;
57-
int b = (thread / width_col / height_col) % batch_size;
58-
int w_in = w_out * stride_w - pad_w;
59-
int h_in = h_out * stride_h - pad_h;
55+
IndexT w_out = thread % width_col;
56+
IndexT h_out = (thread / width_col) % height_col;
57+
IndexT b = (thread / width_col / height_col) % batch_size;
58+
IndexT w_in = w_out * stride_w - pad_w;
59+
IndexT h_in = h_out * stride_h - pad_h;
6060

6161
const T* data_offset_ptr =
6262
data_offset + (b * deformable_group + deformable_group_index) * 2 *
6363
kernel_h * kernel_w * height_col * width_col;
64-
const int data_offset_h_ptr =
64+
const IndexT data_offset_h_ptr =
6565
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
66-
const int data_offset_w_ptr =
66+
const IndexT data_offset_w_ptr =
6767
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
68-
const int data_mask_hw_ptr =
68+
const IndexT data_mask_hw_ptr =
6969
((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
7070
const T offset_h = data_offset_ptr[data_offset_h_ptr];
7171
const T offset_w = data_offset_ptr[data_offset_w_ptr];
@@ -80,14 +80,14 @@ inline void ModulatedDeformableCol2imCPUKernel(
8080
const T mask = data_mask_ptr[data_mask_hw_ptr];
8181
cur_top_grad *= mask;
8282
}
83-
const int cur_h = static_cast<int>(cur_inv_h_data);
84-
const int cur_w = static_cast<int>(cur_inv_w_data);
85-
for (int dy = -2; dy <= 2; dy++) {
86-
for (int dx = -2; dx <= 2; dx++) {
83+
const IndexT cur_h = static_cast<IndexT>(cur_inv_h_data);
84+
const IndexT cur_w = static_cast<IndexT>(cur_inv_w_data);
85+
for (IndexT dy = -2; dy <= 2; dy++) {
86+
for (IndexT dx = -2; dx <= 2; dx++) {
8787
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
8888
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
8989
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
90-
int cur_bottom_grad_pos =
90+
IndexT cur_bottom_grad_pos =
9191
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
9292
T weight = DmcnGetGradientWeight(cur_inv_h_data,
9393
cur_inv_w_data,
@@ -104,7 +104,7 @@ inline void ModulatedDeformableCol2imCPUKernel(
104104
}
105105
}
106106

107-
template <typename T, typename Context>
107+
template <typename T, typename Context, typename IndexT>
108108
void ModulatedDeformableCol2im(const Context& dev_ctx,
109109
const T* data_col,
110110
const T* data_offset,
@@ -117,70 +117,69 @@ void ModulatedDeformableCol2im(const Context& dev_ctx,
117117
const std::vector<int>& dilation,
118118
const int deformable_group,
119119
T* grad_im) {
120-
int channel_per_deformable_group =
121-
static_cast<int>(im_shape[0] / deformable_group);
122-
int num_kernels = static_cast<int>(col_shape[0] * col_shape[1] *
123-
col_shape[2] * col_shape[3]);
120+
int64_t channel_per_deformable_group = im_shape[0] / deformable_group;
121+
int64_t num_kernels =
122+
col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];
124123

125-
ModulatedDeformableCol2imCPUKernel(num_kernels,
126-
data_col,
127-
data_offset,
128-
data_mask,
129-
im_shape[0],
130-
im_shape[1],
131-
im_shape[2],
132-
kernel_shape[2],
133-
kernel_shape[3],
134-
pad[0],
135-
pad[1],
136-
stride[0],
137-
stride[1],
138-
dilation[0],
139-
dilation[1],
140-
channel_per_deformable_group,
141-
col_shape[1],
142-
deformable_group,
143-
col_shape[2],
144-
col_shape[3],
145-
grad_im);
124+
ModulatedDeformableCol2imCPUKernel<T, IndexT>(num_kernels,
125+
data_col,
126+
data_offset,
127+
data_mask,
128+
im_shape[0],
129+
im_shape[1],
130+
im_shape[2],
131+
kernel_shape[2],
132+
kernel_shape[3],
133+
pad[0],
134+
pad[1],
135+
stride[0],
136+
stride[1],
137+
dilation[0],
138+
dilation[1],
139+
channel_per_deformable_group,
140+
col_shape[1],
141+
deformable_group,
142+
col_shape[2],
143+
col_shape[3],
144+
grad_im);
146145
}
147146

148-
template <typename T>
147+
template <typename T, typename IndexT>
149148
void ModulatedDeformableCol2imCoordCPUKernel(
150-
const int num_kernels,
149+
const IndexT num_kernels,
151150
const T* data_col,
152151
const T* data_im,
153152
const T* data_offset,
154153
const T* data_mask,
155-
const int channels,
156-
const int height,
157-
const int width,
158-
const int kernel_h,
159-
const int kernel_w,
160-
const int pad_h,
161-
const int pad_w,
162-
const int stride_h,
163-
const int stride_w,
164-
const int dilation_h,
165-
const int dilation_w,
166-
const int channel_per_deformable_group,
167-
const int batch_size,
168-
const int offset_channels,
169-
const int deformable_group,
170-
const int height_col,
171-
const int width_col,
154+
const IndexT channels,
155+
const IndexT height,
156+
const IndexT width,
157+
const IndexT kernel_h,
158+
const IndexT kernel_w,
159+
const IndexT pad_h,
160+
const IndexT pad_w,
161+
const IndexT stride_h,
162+
const IndexT stride_w,
163+
const IndexT dilation_h,
164+
const IndexT dilation_w,
165+
const IndexT channel_per_deformable_group,
166+
const IndexT batch_size,
167+
const IndexT offset_channels,
168+
const IndexT deformable_group,
169+
const IndexT height_col,
170+
const IndexT width_col,
172171
T* grad_offset,
173172
T* grad_mask) {
174-
for (int i = 0; i < num_kernels; i++) {
173+
for (IndexT i = 0; i < num_kernels; i++) {
175174
T val = 0, mval = 0;
176-
const int w = i % width_col;
177-
const int h = (i / width_col) % height_col;
178-
const int c = (i / width_col / height_col) % offset_channels;
179-
const int b = (i / width_col / height_col) / offset_channels;
175+
const IndexT w = i % width_col;
176+
const IndexT h = (i / width_col) % height_col;
177+
const IndexT c = (i / width_col / height_col) % offset_channels;
178+
const IndexT b = (i / width_col / height_col) / offset_channels;
180179

181-
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
182-
const int col_step = kernel_h * kernel_w;
183-
int cnt = 0;
180+
const IndexT deformable_group_index = c / (2 * kernel_h * kernel_w);
181+
const IndexT col_step = kernel_h * kernel_w;
182+
IndexT cnt = 0;
184183
const T* data_col_ptr = data_col + deformable_group_index *
185184
channel_per_deformable_group *
186185
batch_size * width_col * height_col;
@@ -197,24 +196,25 @@ void ModulatedDeformableCol2imCoordCPUKernel(
197196
kernel_h * kernel_w * height_col * width_col
198197
: nullptr;
199198

200-
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
199+
const IndexT offset_c =
200+
c - deformable_group_index * 2 * kernel_h * kernel_w;
201201

202-
for (int col_c = offset_c / 2; col_c < channel_per_deformable_group;
202+
for (IndexT col_c = offset_c / 2; col_c < channel_per_deformable_group;
203203
col_c += col_step) {
204-
const int col_pos =
204+
const IndexT col_pos =
205205
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
206-
const int bp_dir = offset_c % 2;
206+
const IndexT bp_dir = offset_c % 2;
207207

208-
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
209-
int i =
208+
IndexT j = (col_pos / width_col / height_col / batch_size) % kernel_w;
209+
IndexT i =
210210
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
211-
int w_out = col_pos % width_col;
212-
int h_out = (col_pos / width_col) % height_col;
213-
int w_in = w_out * stride_w - pad_w;
214-
int h_in = h_out * stride_h - pad_h;
215-
const int data_offset_h_ptr =
211+
IndexT w_out = col_pos % width_col;
212+
IndexT h_out = (col_pos / width_col) % height_col;
213+
IndexT w_in = w_out * stride_w - pad_w;
214+
IndexT h_in = h_out * stride_h - pad_h;
215+
const IndexT data_offset_h_ptr =
216216
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
217-
const int data_offset_w_ptr =
217+
const IndexT data_offset_w_ptr =
218218
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
219219
w_out);
220220
const T offset_h = data_offset_ptr[data_offset_h_ptr];
@@ -241,7 +241,7 @@ void ModulatedDeformableCol2imCoordCPUKernel(
241241
width,
242242
bp_dir);
243243
if (data_mask_ptr) {
244-
const int data_mask_hw_ptr =
244+
const IndexT data_mask_hw_ptr =
245245
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
246246
const T mask = data_mask_ptr[data_mask_hw_ptr];
247247
val += weight * data_col_ptr[col_pos] * mask;
@@ -262,7 +262,7 @@ void ModulatedDeformableCol2imCoordCPUKernel(
262262
}
263263
}
264264

265-
template <typename T, typename Context>
265+
template <typename T, typename Context, typename IndexT>
266266
void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
267267
const T* data_col,
268268
const T* data_im,
@@ -277,13 +277,11 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
277277
const int deformable_groups,
278278
T* grad_offset,
279279
T* grad_mask) {
280-
int num_kernels =
281-
static_cast<int>(2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] *
282-
col_shape[2] * col_shape[3] * deformable_groups);
283-
int channel_per_deformable_group =
284-
static_cast<int>(col_shape[0] / deformable_groups);
280+
int64_t num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] *
281+
col_shape[2] * col_shape[3] * deformable_groups;
282+
int64_t channel_per_deformable_group = col_shape[0] / deformable_groups;
285283

286-
ModulatedDeformableCol2imCoordCPUKernel(
284+
ModulatedDeformableCol2imCoordCPUKernel<T, IndexT>(
287285
num_kernels,
288286
data_col,
289287
data_im,
@@ -310,15 +308,15 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
310308
grad_mask);
311309
}
312310

313-
template <typename T, typename Context>
311+
template <typename T, typename Context, typename IndexT>
314312
void FilterGradAddup(const Context& dev_ctx,
315-
const int nthreads,
316-
const int n,
317-
const int height,
318-
const int width,
313+
const int64_t nthreads,
314+
const int64_t n,
315+
const int64_t height,
316+
const int64_t width,
319317
const T* dweight_3d,
320318
T* filter_grad) {
321-
for (int i = 0; i < nthreads; i++) {
319+
for (IndexT i = 0; i < nthreads; i++) {
322320
filter_grad[i] = filter_grad[i] + dweight_3d[i];
323321
}
324322
}

0 commit comments

Comments (0)