Skip to content

Commit fd4c4df

Browse files
authored
Cuda speed for generate_proposals_op. (#13596)
* Add CUDA implementation for generate_proposals_op. * Clean code. * Update code.
1 parent 18e5dcc commit fd4c4df

File tree

4 files changed

+461
-6
lines changed

4 files changed

+461
-6
lines changed

paddle/fluid/operators/detection/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,13 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
3030
polygon_box_transform_op.cu)
3131
detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
3232
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
33-
detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
33+
34+
if(WITH_GPU)
35+
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
36+
else()
37+
detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
38+
endif()
39+
3440
detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu)
3541
#Export local libraries to parent
3642
set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)

paddle/fluid/operators/detection/generate_proposals_op.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include <string>
1616
#include <vector>
1717
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/framework/var_type.h"
1819
#include "paddle/fluid/operators/gather.h"
1920
#include "paddle/fluid/operators/math/math_function.h"
2021

@@ -69,7 +70,7 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
6970
const framework::ExecutionContext &ctx) const override {
7071
return framework::OpKernelType(
7172
framework::ToDataType(ctx.Input<Tensor>("Anchors")->type()),
72-
platform::CPUPlace());
73+
ctx.device_context());
7374
}
7475
};
7576

@@ -162,7 +163,7 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
162163
const T *im_info_data = im_info.data<T>();
163164
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
164165
T im_scale = im_info_data[2];
165-
keep->Resize({boxes->dims()[0], 1});
166+
keep->Resize({boxes->dims()[0]});
166167
min_size = std::max(min_size, 1.0f);
167168
int *keep_data = keep->mutable_data<int>(ctx.GetPlace());
168169

@@ -463,7 +464,7 @@ class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
463464
AddAttr<int>("post_nms_topN", "post_nms_topN");
464465
AddAttr<float>("nms_thresh", "nms_thres");
465466
AddAttr<float>("min_size", "min size");
466-
AddAttr<float>("eta", "eta");
467+
AddAttr<float>("eta", "The parameter for adaptive NMS.");
467468
AddComment(R"DOC(
468469
Generate Proposals OP
469470

0 commit comments

Comments
 (0)