@@ -133,53 +133,47 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
133
133
auto * in = ctx.Input <framework::Tensor>(" X" );
134
134
auto * rois = ctx.Input <framework::Tensor>(" ROIs" );
135
135
auto * argmax = ctx.Input <framework::Tensor>(" Argmax" );
136
-
137
136
auto * out_grad =
138
137
ctx.Input <framework::Tensor>(framework::GradVarName (" Out" ));
139
- auto * x_grad = ctx.Output <framework::Tensor>(framework::GradVarName (" X" ));
138
+ auto * in_grad = ctx.Output <framework::Tensor>(framework::GradVarName (" X" ));
140
139
141
140
auto pooled_height = ctx.Attr <int >(" pooled_height" );
142
141
auto pooled_width = ctx.Attr <int >(" pooled_width" );
143
142
144
- if (x_grad) {
145
- int channels = in->dims ()[1 ];
146
- auto in_stride = framework::stride (in->dims ());
147
- auto roi_stride = framework::stride (rois->dims ());
148
-
143
+ if (in_grad) {
149
144
const int64_t * rois_data = rois->data <int64_t >();
150
- int rois_num = rois-> dims ()[ 0 ] ;
151
-
152
- T* x_grad_data = x_grad ->mutable_data <T>(ctx.GetPlace ());
145
+ const T* out_grad_data = out_grad-> data <T>() ;
146
+ const int64_t * argmax_data = argmax-> data < int64_t >();
147
+ T* in_grad_data = in_grad ->mutable_data <T>(ctx.GetPlace ());
153
148
math::SetConstant<Place, T> set_zero;
154
- set_zero (ctx.device_context (), x_grad , static_cast <T>(0 ));
149
+ set_zero (ctx.device_context (), in_grad , static_cast <T>(0 ));
155
150
156
- size_t roi_offset = roi_stride[0 ];
157
- size_t batch_offset = in_stride[0 ];
158
- size_t channel_offset = in_stride[1 ];
151
+ auto in_stride = framework::stride (in->dims ());
152
+ auto argmax_stride = framework::stride (argmax->dims ());
153
+ auto roi_stride = framework::stride (rois->dims ());
154
+ auto out_stride = framework::stride (out_grad->dims ());
159
155
160
- const T* out_grad_data = out_grad->data <T>();
161
- size_t pool_channel_offset = pooled_height * pooled_width;
162
- const int64_t * argmax_data = argmax->data <int64_t >();
156
+ int rois_num = rois->dims ()[0 ];
157
+ int channels = in->dims ()[1 ];
163
158
164
- for (size_t n = 0 ; n < rois_num; ++n) {
165
- size_t roi_batch_idx = rois_data[0 ];
166
- T* batch_grad_data = x_grad_data + batch_offset * roi_batch_idx ;
159
+ for (int n = 0 ; n < rois_num; ++n) {
160
+ int roi_batch_idx = rois_data[0 ];
161
+ T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[ 0 ] ;
167
162
for (int c = 0 ; c < channels; ++c) {
168
163
for (int ph = 0 ; ph < pooled_height; ++ph) {
169
164
for (int pw = 0 ; pw < pooled_width; ++pw) {
170
- size_t pool_index = ph * pooled_width + pw;
171
-
165
+ int pool_index = ph * pooled_width + pw;
172
166
if (argmax_data[pool_index] >= 0 ) {
173
- size_t index = static_cast < size_t >( argmax_data[pool_index]) ;
167
+ auto index = argmax_data[pool_index];
174
168
batch_grad_data[index] += out_grad_data[pool_index];
175
169
}
176
170
}
177
171
}
178
- batch_grad_data += channel_offset ;
179
- out_grad_data += pool_channel_offset ;
180
- argmax_data += pool_channel_offset ;
172
+ batch_grad_data += in_stride[ 1 ] ;
173
+ out_grad_data += out_stride[ 1 ] ;
174
+ argmax_data += argmax_stride[ 1 ] ;
181
175
}
182
- rois_data += roi_offset ;
176
+ rois_data += roi_stride[ 0 ] ;
183
177
}
184
178
}
185
179
}
0 commit comments