@@ -15,6 +15,7 @@ limitations under the License. */
15
15
#include < string>
16
16
#include < vector>
17
17
#include " paddle/fluid/framework/op_registry.h"
18
+ #include " paddle/fluid/framework/var_type.h"
18
19
#include " paddle/fluid/operators/gather.h"
19
20
#include " paddle/fluid/operators/math/math_function.h"
20
21
@@ -69,7 +70,7 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
69
70
const framework::ExecutionContext &ctx) const override {
70
71
return framework::OpKernelType (
71
72
framework::ToDataType (ctx.Input <Tensor>(" Anchors" )->type ()),
72
- platform::CPUPlace ());
73
+ ctx. device_context ());
73
74
}
74
75
};
75
76
@@ -162,7 +163,7 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
162
163
const T *im_info_data = im_info.data <T>();
163
164
T *boxes_data = boxes->mutable_data <T>(ctx.GetPlace ());
164
165
T im_scale = im_info_data[2 ];
165
- keep->Resize ({boxes->dims ()[0 ], 1 });
166
+ keep->Resize ({boxes->dims ()[0 ]});
166
167
min_size = std::max (min_size, 1 .0f );
167
168
int *keep_data = keep->mutable_data <int >(ctx.GetPlace ());
168
169
@@ -463,7 +464,7 @@ class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
463
464
AddAttr<int >(" post_nms_topN" , " post_nms_topN" );
464
465
AddAttr<float >(" nms_thresh" , " nms_thres" );
465
466
AddAttr<float >(" min_size" , " min size" );
466
- AddAttr<float >(" eta" , " eta " );
467
+ AddAttr<float >(" eta" , " The parameter for adaptive NMS. " );
467
468
AddComment (R"DOC(
468
469
Generate Proposals OP
469
470
0 commit comments