@@ -66,7 +66,6 @@ void AnnotatedDataLayer<Dtype>::DataLayerSetUp(
6666 for (int i = 0 ; i < this ->prefetch_ .size (); ++i) {
6767 this ->prefetch_ [i]->data_ .Reshape (top_shape);
6868 }
69- // LOG(INFO) << "output data size: " << top[0]->num() << ","
7069 LOG_IF (INFO, Caffe::root_solver ())
7170 << " output data size: " << top[0 ]->num () << " ,"
7271 << top[0 ]->channels () << " ," << top[0 ]->height () << " ,"
@@ -75,15 +74,13 @@ void AnnotatedDataLayer<Dtype>::DataLayerSetUp(
7574 if (this ->output_labels_ ) {
7675 has_anno_type_ = anno_datum.has_type () || anno_data_param.has_anno_type ();
7776 if (transform_param.has_caffe_yolo ()) {
77+ this ->box_label_ = true ;
7878 vector<int > label_shape (1 , batch_size);
7979 if (param.side_size () > 0 ) {
8080 for (int i = 0 ; i < param.side_size (); ++i) {
8181 sides_.push_back (param.side (i));
8282 }
8383 }
84- if (sides_.size () == 0 ) {
85- sides_.push_back (7 );
86- }
8784 CHECK_EQ (sides_.size (), top.size () - 1 ) << " side num not equal to top size" ;
8885 if (has_anno_type_) {
8986 anno_type_ = anno_datum.type ();
@@ -93,23 +90,28 @@ void AnnotatedDataLayer<Dtype>::DataLayerSetUp(
9390 LOG (WARNING) << " type stored in AnnotatedDatum is shadowed." ;
9491 anno_type_ = anno_data_param.anno_type ();
9592 }
96- if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
93+ for (int i = 0 ; i < this ->prefetch_ .size (); ++i) {
94+ this ->prefetch_ [i]->multi_label_ .clear ();
95+ }
96+ if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
9797 // Yolo label format
9898 for (int i = 0 ; i < sides_.size (); ++i) {
9999 vector<int > label_shape (1 , batch_size);
100100 int label_size = sides_[i] * sides_[i] * (1 + 1 + 1 + 4 );
101101 label_shape.push_back (label_size);
102102 top[i+1 ]->Reshape (label_shape);
103+ for (int j = 0 ; j < this ->prefetch_ .size (); ++j) {
104+ shared_ptr<Blob<Dtype> > tmp_blob;
105+ tmp_blob.reset (new Blob<Dtype>(label_shape));
106+ this ->prefetch_ [j]->multi_label_ .push_back (tmp_blob);
107+ }
103108 }
104109 } else {
105110 LOG (FATAL) << " Unknown annotation type." ;
106111 }
107112 } else {
108113 label_shape[0 ] = batch_size;
109114 }
110- for (int i = 0 ; i < this ->prefetch_ .size (); ++i) {
111- this ->prefetch_ [i]->label_ .Reshape (label_shape);
112- }
113115 }
114116 else {
115117 vector<int > label_shape (4 , 1 );
@@ -214,7 +216,10 @@ void AnnotatedDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
214216 // Store transformed annotation.
215217 map<int , vector<AnnotationGroup> > all_anno;
216218 int num_bboxes = 0 ;
217-
219+ vector<Dtype*> top_label;
220+ for (int i = 0 ; i < sides_.size (); ++i) {
221+ top_label.push_back (batch->multi_label_ [i]->mutable_cpu_data ());
222+ }
218223 for (int item_id = 0 ; item_id < batch_size; ++item_id) {
219224 timer.Start ();
220225 // get a anno_datum
@@ -333,47 +338,47 @@ void AnnotatedDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
333338 for (int i = 0 ; i < sides_.size (); ++i) {
334339 side = sides_[i];
335340 count = sides_[i] * sides_[i] * (1 + 1 + 1 + 4 );
336- }
337- label_shape[0 ] = batch_size ;
338- label_shape[ 1 ] = count ;
339- batch->label_ . Reshape (label_shape );
340- Dtype* top_label = batch-> label_ . mutable_cpu_data () ;
341- const vector<AnnotationGroup>& anno_vec = all_anno[ item_id] ;
342- label_offset = count * item_id ;
343- top_label = top_label + label_offset ;
344- int locations = pow (side, 2 ) ;
345- CHECK_EQ (count, locations * 7 ) << " side and count not match " ;
346- // difficult
347- caffe_set (locations, Dtype ( 0 ), top_label);
348- // isobj
349- caffe_set (locations, Dtype ( 0 ), top_label + locations);
350- // class label
351- caffe_set (locations, Dtype (- 1 ), top_label + locations * 2 );
352- // bounding box
353- caffe_set (locations* 4 , Dtype ( 0 ), top_label + locations * 3 );
354- for ( int g = 0 ; g < anno_vec. size (); ++g) {
355- const AnnotationGroup& anno_group = anno_vec[g];
356- for ( int a = 0 ; a < anno_group.annotation_size (); ++a) {
357- const Annotation& anno = anno_group. annotation (a );
358- const NormalizedBBox& bbox = anno. bbox ();
359- float class_label = anno_group. group_label ();
360- float x = bbox.x_center ();
361- float y = bbox. y_center ( );
362- int x_index = floor (x * side);
363- int y_index = floor (y * side);
364- x_index = std::min (x_index , side - 1 );
365- y_index = std::min (y_index, side - 1 ) ;
366- int dif_index = side * y_index + x_index ;
367- int obj_index = locations + dif_index;
368- int class_index = locations * 2 + dif_index;
369- int cor_index = locations * 3 + dif_index * 4 ;
370- top_label[dif_index] = bbox. difficult () ;
371- top_label[obj_index] = 1 ;
372- top_label[class_index] = class_label ;
373- top_label[cor_index + 0 ] = bbox.x_center ();
374- top_label[cor_index + 1 ] = bbox.y_center ();
375- top_label[cor_index + 2 ] = bbox.width ();
376- top_label[cor_index + 3 ] = bbox. height ();
341+ label_shape[ 0 ] = batch_size;
342+ label_shape[1 ] = count ;
343+ batch-> multi_label_ [i]-> Reshape (label_shape) ;
344+ top_label[i] = batch->multi_label_ [i]-> mutable_cpu_data ( );
345+ const vector<AnnotationGroup>& anno_vec = all_anno[item_id] ;
346+ label_offset = count * item_id;
347+ top_label[i] = top_label[i] + label_offset ;
348+ int locations = pow (side, 2 ) ;
349+ CHECK_EQ (count, locations * 7 ) << " side and count not match " ;
350+ // difficult
351+ caffe_set (locations, Dtype ( 0 ), top_label[i]);
352+ // isobj
353+ caffe_set (locations, Dtype ( 0 ), top_label[i] + locations);
354+ // class label
355+ caffe_set (locations, Dtype (- 1 ), top_label[i] + locations * 2 );
356+ // bounding box
357+ caffe_set (locations* 4 , Dtype ( 0 ), top_label[i] + locations * 3 );
358+ for ( int g = 0 ; g < anno_vec. size (); ++g) {
359+ const AnnotationGroup& anno_group = anno_vec[g];
360+ for ( int a = 0 ; a < anno_group. annotation_size (); ++a) {
361+ const Annotation& anno = anno_group.annotation (a);
362+ const NormalizedBBox& bbox = anno. bbox ( );
363+ float class_label = anno_group. group_label ();
364+ float x = bbox. x_center ();
365+ float y = bbox.y_center ();
366+ int x_index = floor (x * side );
367+ int y_index = floor (y * side);
368+ x_index = std::min (x_index, side - 1 );
369+ y_index = std::min (y_index , side - 1 );
370+ int dif_index = side * y_index + x_index ;
371+ int obj_index = locations + dif_index ;
372+ int class_index = locations * 2 + dif_index;
373+ int cor_index = locations * 3 + dif_index * 4 ;
374+ top_label[i][ dif_index] = bbox. difficult () ;
375+ top_label[i][obj_index] = 1 ;
376+ top_label[i][class_index] = class_label ;
377+ top_label[i][cor_index + 0 ] = bbox. x_center () ;
378+ top_label[i][ cor_index + 1 ] = bbox.y_center ();
379+ top_label[i][ cor_index + 2 ] = bbox.width ();
380+ top_label[i][ cor_index + 3 ] = bbox.height ();
381+ }
377382 }
378383 }
379384 }
@@ -401,7 +406,6 @@ void AnnotatedDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
401406 // reader_.free().push(const_cast<AnnotatedDatum*>(&anno_datum));
402407 Next ();
403408 }
404-
405409 // Store "rich" annotation if needed.
406410 if (this ->output_labels_ && has_anno_type_) {
407411 if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
0 commit comments