@@ -156,57 +156,39 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input,
156
156
157
157
// CalcLabelLoss: accumulates the per-class classification loss for one
// matched prediction into loss[0].
//
// How to read this span: it is a rendered diff. Lines starting with '-' are
// the version being removed, lines starting with '+' are the version being
// added, and bare numbers are old/new line-number residue.
//
// Added ('+') version:
//   loss       output accumulator; only loss[0] is touched, and it is added
//              to (+=), never reset here.
//   input      prediction buffer; class i is read at input[index + i * stride].
//   index      offset of class 0 for this prediction inside `input`.
//   label      class id that receives target `pos`; every other class in
//              [0, class_num) receives target `neg`.
//   class_num  number of classes iterated.
//   stride     distance in elements between consecutive class entries.
//   pos, neg   target values passed as the second argument of SCE<T>.
//
// Removed ('-') version: took `score` and `use_label_smooth` instead of
// pos/neg. The target was `score` for the labelled class and 0.0 otherwise;
// with use_label_smooth set, the prediction was kept only when it was below
// -0.5 and replaced by 1.0 / class_num otherwise.
//
// NOTE(review): SCE<T> is defined outside this view; presumably a sigmoid
// cross-entropy of (prediction, target) -- confirm against its definition.
template <typename T>
158
158
static inline void CalcLabelLoss (T* loss, const T* input, const int index,
159
- const int label, const T score,
160
- const int class_num, const int stride,
161
- const bool use_label_smooth) {
162
- if (use_label_smooth) {
163
- for (int i = 0 ; i < class_num; i++) {
164
- T pred = input[index + i * stride] < -0.5 ? input[index + i * stride]
165
- : 1.0 / class_num;
166
- loss[0 ] += SCE<T>(pred, (i == label) ? score : 0.0 );
167
- }
168
- } else {
169
- for (int i = 0 ; i < class_num; i++) {
170
- T pred = input[index + i * stride];
171
- loss[0 ] += SCE<T>(pred, (i == label) ? score : 0.0 );
172
- }
159
+ const int label, const int class_num,
160
+ const int stride, const T pos, const T neg) {
161
+ for (int i = 0 ; i < class_num; i++) {
162
+ T pred = input[index + i * stride];
163
+ loss[0 ] += SCE<T>(pred, (i == label) ? pos : neg);
173
164
}
174
165
}
175
166
176
167
// CalcLabelLossGrad: writes the gradient of the per-class classification
// loss for one matched prediction into input_grad.
//
// How to read this span: it is a rendered diff. Lines starting with '-' are
// the version being removed, lines starting with '+' are the version being
// added, and bare numbers are old/new line-number residue.
//
// Added ('+') version:
//   input_grad  output buffer; entry input_grad[index + i * stride] is
//               overwritten (=, not +=) for every class i in [0, class_num).
//   loss        scalar multiplied into every SCEGrad<T> result. The call
//               site visible in this diff passes loss_grad_data[i], so
//               despite its name this appears to be the upstream gradient
//               of the loss -- confirm.
//   input       prediction buffer, read at the same index + i * stride
//               offsets that are written in input_grad.
//   index       offset of class 0 for this prediction.
//   label       class id that receives target `pos`; all others get `neg`.
//   class_num   number of classes iterated.
//   stride      distance in elements between consecutive class entries.
//   pos, neg    target values passed as the second argument of SCEGrad<T>.
//
// Removed ('-') version: took `score` and `use_label_smooth` instead of
// pos/neg, with the same target and prediction selection as the removed
// CalcLabelLoss (target `score` or 0.0; smoothed prediction replaced by
// 1.0 / class_num unless it was below -0.5).
//
// NOTE(review): SCEGrad<T> is defined outside this view; presumably the
// derivative of SCE with respect to the prediction -- confirm.
template <typename T>
177
168
static inline void CalcLabelLossGrad (T* input_grad, const T loss,
178
169
const T* input, const int index,
179
- const int label, const T score,
180
- const int class_num, const int stride,
181
- const bool use_label_smooth) {
182
- if (use_label_smooth) {
183
- for (int i = 0 ; i < class_num; i++) {
184
- T pred = input[index + i * stride] < -0.5 ? input[index + i * stride]
185
- : 1.0 / class_num;
186
- input_grad[index + i * stride] =
187
- SCEGrad<T>(pred, (i == label) ? score : 0.0 ) * loss;
188
- }
189
- } else {
190
- for (int i = 0 ; i < class_num; i++) {
191
- T pred = input[index + i * stride];
192
- input_grad[index + i * stride] =
193
- SCEGrad<T>(pred, (i == label) ? score : 0.0 ) * loss;
194
- }
170
+ const int label, const int class_num,
171
+ const int stride, const T pos,
172
+ const T neg) {
173
+ for (int i = 0 ; i < class_num; i++) {
174
+ T pred = input[index + i * stride];
175
+ input_grad[index + i * stride] =
176
+ SCEGrad<T>(pred, (i == label) ? pos : neg) * loss;
195
177
}
196
178
}
197
179
198
180
template <typename T>
199
- static inline void CalcObjnessLoss (T* loss, const T* input, const int * objness,
181
+ static inline void CalcObjnessLoss (T* loss, const T* input, const T * objness,
200
182
const int n, const int an_num, const int h,
201
183
const int w, const int stride,
202
184
const int an_stride) {
203
185
for (int i = 0 ; i < n; i++) {
204
186
for (int j = 0 ; j < an_num; j++) {
205
187
for (int k = 0 ; k < h; k++) {
206
188
for (int l = 0 ; l < w; l++) {
207
- int obj = objness[k * w + l];
208
- if (obj >= 0 ) {
209
- loss[i] += SCE<T>(input[k * w + l], static_cast <T>( obj) );
189
+ T obj = objness[k * w + l];
190
+ if (obj > - 0.5 ) {
191
+ loss[i] += SCE<T>(input[k * w + l], obj);
210
192
}
211
193
}
212
194
}
@@ -218,18 +200,17 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const int* objness,
218
200
219
201
template <typename T>
220
202
static inline void CalcObjnessLossGrad (T* input_grad, const T* loss,
221
- const T* input, const int * objness,
203
+ const T* input, const T * objness,
222
204
const int n, const int an_num,
223
205
const int h, const int w,
224
206
const int stride, const int an_stride) {
225
207
for (int i = 0 ; i < n; i++) {
226
208
for (int j = 0 ; j < an_num; j++) {
227
209
for (int k = 0 ; k < h; k++) {
228
210
for (int l = 0 ; l < w; l++) {
229
- int obj = objness[k * w + l];
230
- if (obj >= 0 ) {
231
- input_grad[k * w + l] =
232
- SCEGrad<T>(input[k * w + l], static_cast <T>(obj)) * loss[i];
211
+ T obj = objness[k * w + l];
212
+ if (obj > -0.5 ) {
213
+ input_grad[k * w + l] = SCEGrad<T>(input[k * w + l], obj) * loss[i];
233
214
}
234
215
}
235
216
}
@@ -285,15 +266,22 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
285
266
const int stride = h * w;
286
267
const int an_stride = (class_num + 5 ) * stride;
287
268
269
+ T label_pos = 1.0 ;
270
+ T label_neg = 0.0 ;
271
+ if (use_label_smooth) {
272
+ label_pos = 1.0 - 1.0 / static_cast <T>(class_num);
273
+ label_neg = 1.0 / static_cast <T>(class_num);
274
+ }
275
+
288
276
const T* input_data = input->data <T>();
289
277
const T* gt_box_data = gt_box->data <T>();
290
278
const int * gt_label_data = gt_label->data <int >();
291
279
const T* gt_score_data = gt_score->data <T>();
292
280
T* loss_data = loss->mutable_data <T>({n}, ctx.GetPlace ());
293
281
memset (loss_data, 0 , loss->numel () * sizeof (T));
294
- int * obj_mask_data =
295
- objness_mask->mutable_data <int >({n, mask_num, h, w}, ctx.GetPlace ());
296
- memset (obj_mask_data, 0 , objness_mask->numel () * sizeof (int ));
282
+ T * obj_mask_data =
283
+ objness_mask->mutable_data <T >({n, mask_num, h, w}, ctx.GetPlace ());
284
+ memset (obj_mask_data, 0 , objness_mask->numel () * sizeof (T ));
297
285
int * gt_match_mask_data =
298
286
gt_match_mask->mutable_data <int >({n, b}, ctx.GetPlace ());
299
287
@@ -327,7 +315,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
327
315
328
316
if (best_iou > ignore_thresh) {
329
317
int obj_idx = (i * mask_num + j) * stride + k * w + l;
330
- obj_mask_data[obj_idx] = - 1 ;
318
+ obj_mask_data[obj_idx] = static_cast <T>(- 1.0 ) ;
331
319
}
332
320
// TODO(dengkaipeng): all losses should be calculated if best IoU
333
321
// is bigger than truth thresh should be calculated here, but
@@ -374,15 +362,15 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
374
362
CalcBoxLocationLoss<T>(loss_data + i, input_data, gt, anchors, best_n,
375
363
box_idx, gi, gj, h, input_size, stride);
376
364
365
+ T score = gt_score_data[i * b + t];
377
366
int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi;
378
- obj_mask_data[obj_idx] = 1 ;
367
+ obj_mask_data[obj_idx] = score ;
379
368
380
369
int label = gt_label_data[i * b + t];
381
- T score = gt_score_data[i * b + t];
382
370
int label_idx = GetEntryIndex (i, mask_idx, gj * w + gi, mask_num,
383
371
an_stride, stride, 5 );
384
- CalcLabelLoss<T>(loss_data + i, input_data, label_idx, label, score,
385
- class_num, stride, use_label_smooth );
372
+ CalcLabelLoss<T>(loss_data + i, input_data, label_idx, label,
373
+ class_num, stride, label_pos, label_neg );
386
374
}
387
375
}
388
376
}
@@ -399,7 +387,6 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
399
387
auto * input = ctx.Input <Tensor>(" X" );
400
388
auto * gt_box = ctx.Input <Tensor>(" GTBox" );
401
389
auto * gt_label = ctx.Input <Tensor>(" GTLabel" );
402
- auto * gt_score = ctx.Input <Tensor>(" GTScore" );
403
390
auto * input_grad = ctx.Output <Tensor>(framework::GradVarName (" X" ));
404
391
auto * loss_grad = ctx.Input <Tensor>(framework::GradVarName (" Loss" ));
405
392
auto * objness_mask = ctx.Input <Tensor>(" ObjectnessMask" );
@@ -421,12 +408,18 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
421
408
const int stride = h * w;
422
409
const int an_stride = (class_num + 5 ) * stride;
423
410
411
+ T label_pos = 1.0 ;
412
+ T label_neg = 0.0 ;
413
+ if (use_label_smooth) {
414
+ label_pos = 1.0 - 1.0 / static_cast <T>(class_num);
415
+ label_neg = 1.0 / static_cast <T>(class_num);
416
+ }
417
+
424
418
const T* input_data = input->data <T>();
425
419
const T* gt_box_data = gt_box->data <T>();
426
420
const int * gt_label_data = gt_label->data <int >();
427
- const T* gt_score_data = gt_score->data <T>();
428
421
const T* loss_grad_data = loss_grad->data <T>();
429
- const int * obj_mask_data = objness_mask->data <int >();
422
+ const T * obj_mask_data = objness_mask->data <T >();
430
423
const int * gt_match_mask_data = gt_match_mask->data <int >();
431
424
T* input_grad_data =
432
425
input_grad->mutable_data <T>({n, c, h, w}, ctx.GetPlace ());
@@ -447,12 +440,11 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
447
440
anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride);
448
441
449
442
int label = gt_label_data[i * b + t];
450
- T score = gt_score_data[i * b + t];
451
443
int label_idx = GetEntryIndex (i, mask_idx, gj * w + gi, mask_num,
452
444
an_stride, stride, 5 );
453
445
CalcLabelLossGrad<T>(input_grad_data, loss_grad_data[i], input_data,
454
- label_idx, label, score, class_num, stride,
455
- use_label_smooth );
446
+ label_idx, label, class_num, stride, label_pos ,
447
+ label_neg );
456
448
}
457
449
}
458
450
}
0 commit comments