9
9
See the License for the specific language governing permissions and
10
10
limitations under the License. */
11
11
12
- #include " paddle/fluid/operators/yolov3_loss_op.h"
12
+ #include " paddle/fluid/operators/detection/ yolov3_loss_op.h"
13
13
#include " paddle/fluid/framework/op_registry.h"
14
14
15
15
namespace paddle {
@@ -29,23 +29,33 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
29
29
" Input(GTLabel) of Yolov3LossOp should not be null." );
30
30
PADDLE_ENFORCE (ctx->HasOutput (" Loss" ),
31
31
" Output(Loss) of Yolov3LossOp should not be null." );
32
+ PADDLE_ENFORCE (
33
+ ctx->HasOutput (" ObjectnessMask" ),
34
+ " Output(ObjectnessMask) of Yolov3LossOp should not be null." );
35
+ PADDLE_ENFORCE (ctx->HasOutput (" GTMatchMask" ),
36
+ " Output(GTMatchMask) of Yolov3LossOp should not be null." );
32
37
33
38
auto dim_x = ctx->GetInputDim (" X" );
34
39
auto dim_gtbox = ctx->GetInputDim (" GTBox" );
35
40
auto dim_gtlabel = ctx->GetInputDim (" GTLabel" );
36
41
auto anchors = ctx->Attrs ().Get <std::vector<int >>(" anchors" );
42
+ int anchor_num = anchors.size () / 2 ;
43
+ auto anchor_mask = ctx->Attrs ().Get <std::vector<int >>(" anchor_mask" );
44
+ int mask_num = anchor_mask.size ();
37
45
auto class_num = ctx->Attrs ().Get <int >(" class_num" );
46
+
38
47
PADDLE_ENFORCE_EQ (dim_x.size (), 4 , " Input(X) should be a 4-D tensor." );
39
48
PADDLE_ENFORCE_EQ (dim_x[2 ], dim_x[3 ],
40
49
" Input(X) dim[3] and dim[4] should be euqal." );
41
- PADDLE_ENFORCE_EQ (dim_x[1 ], anchors.size () / 2 * (5 + class_num),
42
- " Input(X) dim[1] should be equal to (anchor_number * (5 "
43
- " + class_num))." );
50
+ PADDLE_ENFORCE_EQ (
51
+ dim_x[1 ], mask_num * (5 + class_num),
52
+ " Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
53
+ " + class_num))." );
44
54
PADDLE_ENFORCE_EQ (dim_gtbox.size (), 3 ,
45
55
" Input(GTBox) should be a 3-D tensor" );
46
56
PADDLE_ENFORCE_EQ (dim_gtbox[2 ], 4 , " Input(GTBox) dim[2] should be 5" );
47
57
PADDLE_ENFORCE_EQ (dim_gtlabel.size (), 2 ,
48
- " Input(GTBox ) should be a 2-D tensor" );
58
+ " Input(GTLabel ) should be a 2-D tensor" );
49
59
PADDLE_ENFORCE_EQ (dim_gtlabel[0 ], dim_gtbox[0 ],
50
60
" Input(GTBox) and Input(GTLabel) dim[0] should be same" );
51
61
PADDLE_ENFORCE_EQ (dim_gtlabel[1 ], dim_gtbox[1 ],
@@ -54,11 +64,22 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
54
64
" Attr(anchors) length should be greater then 0." );
55
65
PADDLE_ENFORCE_EQ (anchors.size () % 2 , 0 ,
56
66
" Attr(anchors) length should be even integer." );
67
+ for (size_t i = 0 ; i < anchor_mask.size (); i++) {
68
+ PADDLE_ENFORCE_LT (
69
+ anchor_mask[i], anchor_num,
70
+ " Attr(anchor_mask) should not crossover Attr(anchors)." );
71
+ }
57
72
PADDLE_ENFORCE_GT (class_num, 0 ,
58
73
" Attr(class_num) should be an integer greater then 0." );
59
74
60
- std::vector<int64_t > dim_out ({1 });
75
+ std::vector<int64_t > dim_out ({dim_x[ 0 ] });
61
76
ctx->SetOutputDim (" Loss" , framework::make_ddim (dim_out));
77
+
78
+ std::vector<int64_t > dim_obj_mask ({dim_x[0 ], mask_num, dim_x[2 ], dim_x[3 ]});
79
+ ctx->SetOutputDim (" ObjectnessMask" , framework::make_ddim (dim_obj_mask));
80
+
81
+ std::vector<int64_t > dim_gt_match_mask ({dim_gtbox[0 ], dim_gtbox[1 ]});
82
+ ctx->SetOutputDim (" GTMatchMask" , framework::make_ddim (dim_gt_match_mask));
62
83
}
63
84
64
85
protected:
@@ -73,11 +94,11 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
73
94
public:
74
95
void Make () override {
75
96
AddInput (" X" ,
76
- " The input tensor of YOLO v3 loss operator, "
97
+ " The input tensor of YOLOv3 loss operator, "
77
98
" This is a 4-D tensor with shape of [N, C, H, W]."
78
99
" H and W should be same, and the second dimention(C) stores"
79
100
" box locations, confidence score and classification one-hot"
80
- " key of each anchor box" );
101
+ " keys of each anchor box" );
81
102
AddInput (" GTBox" ,
82
103
" The input tensor of ground truth boxes, "
83
104
" This is a 3-D tensor with shape of [N, max_box_num, 5], "
@@ -89,32 +110,39 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
89
110
AddInput (" GTLabel" ,
90
111
" The input tensor of ground truth label, "
91
112
" This is a 2-D tensor with shape of [N, max_box_num], "
92
- " and each element shoudl be an integer to indicate the "
113
+ " and each element should be an integer to indicate the "
93
114
" box class id." );
94
115
AddOutput (" Loss" ,
95
116
" The output yolov3 loss tensor, "
96
- " This is a 1-D tensor with shape of [1]" );
117
+ " This is a 1-D tensor with shape of [N]" );
118
+ AddOutput (" ObjectnessMask" ,
119
+ " This is an intermediate tensor with shape of [N, M, H, W], "
120
+ " M is the number of anchor masks. This parameter caches the "
121
+ " mask for calculate objectness loss in gradient kernel." )
122
+ .AsIntermediate ();
123
+ AddOutput (" GTMatchMask" ,
124
+ " This is an intermediate tensor with shape of [N, B], "
125
+ " B is the max box number of GT boxes. This parameter caches "
126
+ " matched mask index of each GT boxes for gradient calculate." )
127
+ .AsIntermediate ();
97
128
98
129
AddAttr<int >(" class_num" , " The number of classes to predict." );
99
130
AddAttr<std::vector<int >>(" anchors" ,
100
131
" The anchor width and height, "
101
- " it will be parsed pair by pair." );
132
+ " it will be parsed pair by pair." )
133
+ .SetDefault (std::vector<int >{});
134
+ AddAttr<std::vector<int >>(" anchor_mask" ,
135
+ " The mask index of anchors used in "
136
+ " current YOLOv3 loss calculation." )
137
+ .SetDefault (std::vector<int >{});
138
+ AddAttr<int >(" downsample_ratio" ,
139
+ " The downsample ratio from network input to YOLOv3 loss "
140
+ " input, so 32, 16, 8 should be set for the first, second, "
141
+ " and thrid YOLOv3 loss operators." )
142
+ .SetDefault (32 );
102
143
AddAttr<float >(" ignore_thresh" ,
103
- " The ignore threshold to ignore confidence loss." );
104
- AddAttr<float >(" loss_weight_xy" , " The weight of x, y location loss." )
105
- .SetDefault (1.0 );
106
- AddAttr<float >(" loss_weight_wh" , " The weight of w, h location loss." )
107
- .SetDefault (1.0 );
108
- AddAttr<float >(
109
- " loss_weight_conf_target" ,
110
- " The weight of confidence score loss in locations with target object." )
111
- .SetDefault (1.0 );
112
- AddAttr<float >(" loss_weight_conf_notarget" ,
113
- " The weight of confidence score loss in locations without "
114
- " target object." )
115
- .SetDefault (1.0 );
116
- AddAttr<float >(" loss_weight_class" , " The weight of classification loss." )
117
- .SetDefault (1.0 );
144
+ " The ignore threshold to ignore confidence loss." )
145
+ .SetDefault (0.7 );
118
146
AddComment (R"DOC(
119
147
This operator generates the yolov3 loss from the given prediction results and ground
120
148
truth boxes.
@@ -147,17 +175,28 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
147
175
thresh, the confidence score loss of this anchor box will be ignored.
148
176
149
177
Therefore, the yolov3 loss consists of three major parts: box location loss,
150
- confidence score loss, and classification loss. The MSE loss is used for
151
- box location, and binary cross entropy loss is used for confidence score
152
- loss and classification loss.
178
+ confidence score loss, and classification loss. The L2 loss is used for
179
+ box coordinates (w, h), and sigmoid cross entropy loss is used for box
180
+ coordinates (x, y), confidence score loss and classification loss.
181
+
182
+ Each ground truth box finds a best matching anchor box among all anchors,
183
+ prediction of this anchor box will incur all three parts of losses, and
184
+ prediction of anchor boxes with no GT box matched will only incur objectness
185
+ loss.
186
+
187
+ In order to trade off box coordinate losses between big boxes and small
188
+ boxes, box coordinate losses will be multiplied by a scale weight, which is
189
+ calculated as follows.
190
+
191
+ $$
192
+ weight_{box} = 2.0 - t_w * t_h
193
+ $$
153
194
154
195
The final loss is represented as follows.
155
196
156
197
$$
157
- loss = \loss_weight_{xy} * loss_{xy} + \loss_weight_{wh} * loss_{wh}
158
- + \loss_weight_{conf_target} * loss_{conf_target}
159
- + \loss_weight_{conf_notarget} * loss_{conf_notarget}
160
- + \loss_weight_{class} * loss_{class}
198
+ loss = (loss_{xy} + loss_{wh}) * weight_{box}
199
+ + loss_{conf} + loss_{class}
161
200
$$
162
201
)DOC" );
163
202
}
@@ -196,6 +235,8 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
196
235
op->SetInput (" GTBox" , Input (" GTBox" ));
197
236
op->SetInput (" GTLabel" , Input (" GTLabel" ));
198
237
op->SetInput (framework::GradVarName (" Loss" ), OutputGrad (" Loss" ));
238
+ op->SetInput (" ObjectnessMask" , Output (" ObjectnessMask" ));
239
+ op->SetInput (" GTMatchMask" , Output (" GTMatchMask" ));
199
240
200
241
op->SetAttrMap (Attrs ());
201
242
0 commit comments