Skip to content

Commit 3df07c2

Browse files
committed
add arm_loc handling for ssd
1 parent afd3eff commit 3df07c2

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed

src/caffe/layers/detection_output_layer.cpp

Lines changed: 36 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -401,6 +401,41 @@ void DetectionOutputLayer<Dtype>::Forward_cpu(
401401
}
402402
}
403403

404+
if (arm_loc_no_concat_ && no_permute_)
405+
{
406+
all_arm_loc_preds.clear();
407+
all_arm_loc_preds.resize(num);
408+
if (share_location_) {
409+
CHECK_EQ(num_loc_classes_, 1);
410+
}
411+
for (int i = 0; i < num; ++i) {
412+
for (int n = 0; n < nbottom_; n++) {
413+
const Dtype *arm_loc_data = bottom[n + 1 + 3*nbottom_]->cpu_data();
414+
LabelBBox &arm_label_bbox = all_arm_loc_preds[i];
415+
int count = bottom[n + 1 + 3*nbottom_]->height() * bottom[n + 1 + 3*nbottom_]->width();
416+
for (int r = 0; r < bottom[n + 1+ 3*nbottom_]->channels() / 4 / num_loc_classes_; ++r) {
417+
int start_idx = r * num_loc_classes_ * 4 * count;
418+
for (int p = 0; p < count; ++p) {
419+
for (int c = 0; c < num_loc_classes_; ++c) {
420+
int label = share_location_ ? -1 : c;
421+
// if (label_bbox.find(label) == label_bbox.end()) {
422+
// label_bbox[label].resize(num_priors_);
423+
//}
424+
NormalizedBBox locbox;
425+
locbox.set_xmin(arm_loc_data[start_idx + c * 4 * count + p]);
426+
locbox.set_ymin(arm_loc_data[start_idx + c * 4 * count + p + count]);
427+
locbox.set_xmax(arm_loc_data[start_idx + c * 4 * count + p + 2 * count]);
428+
locbox.set_ymax(arm_loc_data[start_idx + c * 4 * count + p + 3 * count]);
429+
float locbox_size = BBoxSize(locbox);
430+
locbox.set_size(locbox_size);
431+
arm_label_bbox[label].push_back(locbox);
432+
}
433+
}
434+
}
435+
}
436+
}
437+
}
438+
404439
// Retrieve all confidences.
405440
vector<map<int, vector<float>>> all_conf_scores;
406441
if (conf_concat_ && arm_conf_data != NULL) {
@@ -664,7 +699,7 @@ void DetectionOutputLayer<Dtype>::Forward_cpu(
664699
// Decode all loc predictions to bboxes.
665700
vector<LabelBBox> all_decode_bboxes;
666701
const bool clip_bbox = false;
667-
if (bottom.size() >= 5 && loc_concat_) {
702+
if ((bottom.size() >= 5 && loc_concat_) || arm_loc_no_concat_) {
668703
CasRegDecodeBBoxesAll(all_loc_preds, prior_bboxes, prior_variances, num,
669704
share_location_, num_loc_classes_,
670705
background_label_id_, code_type_,

0 commit comments

Comments
 (0)