Skip to content

Commit 30cc8b7

Browse files
authored
Merge pull request #15554 from heavengate/yolo_loss_darknet
Yolo loss darknet
2 parents 1a252f4 + 23d34d1 commit 30cc8b7

File tree

8 files changed

+694
-691
lines changed

8 files changed

+694
-691
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes',
324324
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
325325
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0))
326326
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
327-
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
327+
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,))
328328
paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None))
329329
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
330330
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))

paddle/fluid/operators/detection/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
3131
polygon_box_transform_op.cu)
3232
detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
3333
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
34+
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
3435

3536
if(WITH_GPU)
3637
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)

paddle/fluid/operators/yolov3_loss_op.cc renamed to paddle/fluid/operators/detection/yolov3_loss_op.cc

Lines changed: 74 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
See the License for the specific language governing permissions and
1010
limitations under the License. */
1111

12-
#include "paddle/fluid/operators/yolov3_loss_op.h"
12+
#include "paddle/fluid/operators/detection/yolov3_loss_op.h"
1313
#include "paddle/fluid/framework/op_registry.h"
1414

1515
namespace paddle {
@@ -29,23 +29,33 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
2929
"Input(GTLabel) of Yolov3LossOp should not be null.");
3030
PADDLE_ENFORCE(ctx->HasOutput("Loss"),
3131
"Output(Loss) of Yolov3LossOp should not be null.");
32+
PADDLE_ENFORCE(
33+
ctx->HasOutput("ObjectnessMask"),
34+
"Output(ObjectnessMask) of Yolov3LossOp should not be null.");
35+
PADDLE_ENFORCE(ctx->HasOutput("GTMatchMask"),
36+
"Output(GTMatchMask) of Yolov3LossOp should not be null.");
3237

3338
auto dim_x = ctx->GetInputDim("X");
3439
auto dim_gtbox = ctx->GetInputDim("GTBox");
3540
auto dim_gtlabel = ctx->GetInputDim("GTLabel");
3641
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
42+
int anchor_num = anchors.size() / 2;
43+
auto anchor_mask = ctx->Attrs().Get<std::vector<int>>("anchor_mask");
44+
int mask_num = anchor_mask.size();
3745
auto class_num = ctx->Attrs().Get<int>("class_num");
46+
3847
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
3948
PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3],
4049
"Input(X) dim[3] and dim[4] should be euqal.");
41-
PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num),
42-
"Input(X) dim[1] should be equal to (anchor_number * (5 "
43-
"+ class_num)).");
50+
PADDLE_ENFORCE_EQ(
51+
dim_x[1], mask_num * (5 + class_num),
52+
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
53+
"+ class_num)).");
4454
PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3,
4555
"Input(GTBox) should be a 3-D tensor");
4656
PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 5");
4757
PADDLE_ENFORCE_EQ(dim_gtlabel.size(), 2,
48-
"Input(GTBox) should be a 2-D tensor");
58+
"Input(GTLabel) should be a 2-D tensor");
4959
PADDLE_ENFORCE_EQ(dim_gtlabel[0], dim_gtbox[0],
5060
"Input(GTBox) and Input(GTLabel) dim[0] should be same");
5161
PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1],
@@ -54,11 +64,22 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
5464
"Attr(anchors) length should be greater then 0.");
5565
PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
5666
"Attr(anchors) length should be even integer.");
67+
for (size_t i = 0; i < anchor_mask.size(); i++) {
68+
PADDLE_ENFORCE_LT(
69+
anchor_mask[i], anchor_num,
70+
"Attr(anchor_mask) should not crossover Attr(anchors).");
71+
}
5772
PADDLE_ENFORCE_GT(class_num, 0,
5873
"Attr(class_num) should be an integer greater then 0.");
5974

60-
std::vector<int64_t> dim_out({1});
75+
std::vector<int64_t> dim_out({dim_x[0]});
6176
ctx->SetOutputDim("Loss", framework::make_ddim(dim_out));
77+
78+
std::vector<int64_t> dim_obj_mask({dim_x[0], mask_num, dim_x[2], dim_x[3]});
79+
ctx->SetOutputDim("ObjectnessMask", framework::make_ddim(dim_obj_mask));
80+
81+
std::vector<int64_t> dim_gt_match_mask({dim_gtbox[0], dim_gtbox[1]});
82+
ctx->SetOutputDim("GTMatchMask", framework::make_ddim(dim_gt_match_mask));
6283
}
6384

6485
protected:
@@ -73,11 +94,11 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
7394
public:
7495
void Make() override {
7596
AddInput("X",
76-
"The input tensor of YOLO v3 loss operator, "
97+
"The input tensor of YOLOv3 loss operator, "
7798
"This is a 4-D tensor with shape of [N, C, H, W]."
7899
"H and W should be same, and the second dimention(C) stores"
79100
"box locations, confidence score and classification one-hot"
80-
"key of each anchor box");
101+
"keys of each anchor box");
81102
AddInput("GTBox",
82103
"The input tensor of ground truth boxes, "
83104
"This is a 3-D tensor with shape of [N, max_box_num, 5], "
@@ -89,32 +110,39 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
89110
AddInput("GTLabel",
90111
"The input tensor of ground truth label, "
91112
"This is a 2-D tensor with shape of [N, max_box_num], "
92-
"and each element shoudl be an integer to indicate the "
113+
"and each element should be an integer to indicate the "
93114
"box class id.");
94115
AddOutput("Loss",
95116
"The output yolov3 loss tensor, "
96-
"This is a 1-D tensor with shape of [1]");
117+
"This is a 1-D tensor with shape of [N]");
118+
AddOutput("ObjectnessMask",
119+
"This is an intermediate tensor with shape of [N, M, H, W], "
120+
"M is the number of anchor masks. This parameter caches the "
121+
"mask for calculate objectness loss in gradient kernel.")
122+
.AsIntermediate();
123+
AddOutput("GTMatchMask",
124+
"This is an intermediate tensor with shape of [N, B], "
125+
"B is the max box number of GT boxes. This parameter caches "
126+
"matched mask index of each GT boxes for gradient calculate.")
127+
.AsIntermediate();
97128

98129
AddAttr<int>("class_num", "The number of classes to predict.");
99130
AddAttr<std::vector<int>>("anchors",
100131
"The anchor width and height, "
101-
"it will be parsed pair by pair.");
132+
"it will be parsed pair by pair.")
133+
.SetDefault(std::vector<int>{});
134+
AddAttr<std::vector<int>>("anchor_mask",
135+
"The mask index of anchors used in "
136+
"current YOLOv3 loss calculation.")
137+
.SetDefault(std::vector<int>{});
138+
AddAttr<int>("downsample_ratio",
139+
"The downsample ratio from network input to YOLOv3 loss "
140+
"input, so 32, 16, 8 should be set for the first, second, "
141+
"and thrid YOLOv3 loss operators.")
142+
.SetDefault(32);
102143
AddAttr<float>("ignore_thresh",
103-
"The ignore threshold to ignore confidence loss.");
104-
AddAttr<float>("loss_weight_xy", "The weight of x, y location loss.")
105-
.SetDefault(1.0);
106-
AddAttr<float>("loss_weight_wh", "The weight of w, h location loss.")
107-
.SetDefault(1.0);
108-
AddAttr<float>(
109-
"loss_weight_conf_target",
110-
"The weight of confidence score loss in locations with target object.")
111-
.SetDefault(1.0);
112-
AddAttr<float>("loss_weight_conf_notarget",
113-
"The weight of confidence score loss in locations without "
114-
"target object.")
115-
.SetDefault(1.0);
116-
AddAttr<float>("loss_weight_class", "The weight of classification loss.")
117-
.SetDefault(1.0);
144+
"The ignore threshold to ignore confidence loss.")
145+
.SetDefault(0.7);
118146
AddComment(R"DOC(
119147
This operator generate yolov3 loss by given predict result and ground
120148
truth boxes.
@@ -147,17 +175,28 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
147175
thresh, the confidence score loss of this anchor box will be ignored.
148176
149177
Therefore, the yolov3 loss consist of three major parts, box location loss,
150-
confidence score loss, and classification loss. The MSE loss is used for
151-
box location, and binary cross entropy loss is used for confidence score
152-
loss and classification loss.
178+
confidence score loss, and classification loss. The L2 loss is used for
179+
box coordinates (w, h), and sigmoid cross entropy loss is used for box
180+
coordinates (x, y), confidence score loss and classification loss.
181+
182+
Each groud truth box find a best matching anchor box in all anchors,
183+
prediction of this anchor box will incur all three parts of losses, and
184+
prediction of anchor boxes with no GT box matched will only incur objectness
185+
loss.
186+
187+
In order to trade off box coordinate losses between big boxes and small
188+
boxes, box coordinate losses will be mutiplied by scale weight, which is
189+
calculated as follow.
190+
191+
$$
192+
weight_{box} = 2.0 - t_w * t_h
193+
$$
153194
154195
Final loss will be represented as follow.
155196
156197
$$
157-
loss = \loss_weight_{xy} * loss_{xy} + \loss_weight_{wh} * loss_{wh}
158-
+ \loss_weight_{conf_target} * loss_{conf_target}
159-
+ \loss_weight_{conf_notarget} * loss_{conf_notarget}
160-
+ \loss_weight_{class} * loss_{class}
198+
loss = (loss_{xy} + loss_{wh}) * weight_{box}
199+
+ loss_{conf} + loss_{class}
161200
$$
162201
)DOC");
163202
}
@@ -196,6 +235,8 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
196235
op->SetInput("GTBox", Input("GTBox"));
197236
op->SetInput("GTLabel", Input("GTLabel"));
198237
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
238+
op->SetInput("ObjectnessMask", Output("ObjectnessMask"));
239+
op->SetInput("GTMatchMask", Output("GTMatchMask"));
199240

200241
op->SetAttrMap(Attrs());
201242

0 commit comments

Comments
 (0)