@@ -478,67 +478,6 @@ def eval(self):
         return avg_distance, avg_instance_error
 
 
-class DetectionMAP(MetricBase):
-    """
-    Calculate the detection mean average precision (mAP).
-    mAP is the metric to measure the accuracy of object detectors
-    like Faster R-CNN, SSD, etc.
-    It is the average of the maximum precisions at different recall values.
-    Please get more information from the following articles:
-    https://sanchom.wordpress.com/tag/average-precision/
-
-    https://arxiv.org/abs/1512.02325
-
-    The general steps are as follows:
-
-    1. calculate the true positive and false positive according to the input
-       of detection and labels.
-    2. calculate mAP value, support two versions: '11 point' and 'integral'.
-
-    Examples:
-        .. code-block:: python
-
-            pred = fluid.layers.fc(input=data, size=1000, act="tanh")
-            batch_map = layers.detection_map(
-                input,
-                label,
-                class_num,
-                background_label,
-                overlap_threshold=overlap_threshold,
-                evaluate_difficult=evaluate_difficult,
-                ap_version=ap_version)
-            metric = fluid.metrics.DetectionMAP()
-            for data in train_reader():
-                loss, preds, labels = exe.run(fetch_list=[cost, batch_map])
-                batch_size = data[0]
-                metric.update(value=batch_map, weight=batch_size)
-                numpy_map = metric.eval()
-    """
-
-    def __init__(self, name=None):
-        super(DetectionMAP, self).__init__(name)
-        # the current map value
-        self.value = .0
-        self.weight = .0
-
-    def update(self, value, weight):
-        if not _is_number_or_matrix_(value):
-            raise ValueError(
-                "The 'value' must be a number(int, float) or a numpy ndarray.")
-        if not _is_number_(weight):
-            raise ValueError("The 'weight' must be a number(int, float).")
-        self.value += value
-        self.weight += weight
-
-    def eval(self):
-        if self.weight == 0:
-            raise ValueError(
-                "There is no data in DetectionMAP Metrics. "
-                "Please check layers.detection_map output has added to DetectionMAP."
-            )
-        return self.value / self.weight
-
-
 class Auc(MetricBase):
     """
     Auc metric adapts to the binary classification.
@@ -616,3 +555,185 @@ def eval(self):
             idx -= 1
 
         return auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0
+
+
+class DetectionMAP(object):
561+ """
562+ Calculate the detection mean average precision (mAP).
563+
564+ The general steps are as follows:
565+ 1. calculate the true positive and false positive according to the input
566+ of detection and labels.
567+ 2. calculate mAP value, support two versions: '11 point' and 'integral'.
568+
569+ Please get more information from the following articles:
570+ https://sanchom.wordpress.com/tag/average-precision/
571+ https://arxiv.org/abs/1512.02325
572+
573+ Args:
574+ input (Variable): The detection results, which is a LoDTensor with shape
575+ [M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax].
576+ gt_label (Variable): The ground truth label index, which is a LoDTensor
577+ with shape [N, 1].
578+ gt_box (Variable): The ground truth bounding box (bbox), which is a
579+ LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax].
580+ gt_difficult (Variable|None): Whether this ground truth is a difficult
581+ bounding bbox, which can be a LoDTensor [N, 1] or not set. If None,
582+ it means all the ground truth labels are not difficult bbox.
583+ class_num (int): The class number.
584+ background_label (int): The index of background label, the background
585+ label will be ignored. If set to -1, then all categories will be
586+ considered, 0 by defalut.
587+ overlap_threshold (float): The threshold for deciding true/false
588+ positive, 0.5 by defalut.
589+ evaluate_difficult (bool): Whether to consider difficult ground truth
590+ for evaluation, True by defalut. This argument does not work when
591+ gt_difficult is None.
592+ ap_version (string): The average precision calculation ways, it must be
593+ 'integral' or '11point'. Please check
594+ https://sanchom.wordpress.com/tag/average-precision/ for details.
595+ - 11point: the 11-point interpolated average precision.
596+ - integral: the natural integral of the precision-recall curve.
597+
598+ Examples:
599+ .. code-block:: python
600+
601+ exe = fluid.executor(place)
602+ map_evaluator = fluid.Evaluator.DetectionMAP(input,
603+ gt_label, gt_box, gt_difficult)
604+ cur_map, accum_map = map_evaluator.get_map_var()
605+ fetch = [cost, cur_map, accum_map]
606+ for epoch in PASS_NUM:
607+ map_evaluator.reset(exe)
608+ for data in batches:
609+ loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch)
610+
611+ In the above example:
612+
613+ 'cur_map_v' is the mAP of current mini-batch.
614+ 'accum_map_v' is the accumulative mAP of one pass.
615+ """
+
+    def __init__(self,
+                 input,
+                 gt_label,
+                 gt_box,
+                 gt_difficult=None,
+                 class_num=None,
+                 background_label=0,
+                 overlap_threshold=0.5,
+                 evaluate_difficult=True,
+                 ap_version='integral'):
+        from . import layers
+        from .layer_helper import LayerHelper
+        from .initializer import Constant
+
+        self.helper = LayerHelper('map_eval')
+        gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype)
+        if gt_difficult:
+            gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype)
+            label = layers.concat([gt_label, gt_difficult, gt_box], axis=1)
+        else:
+            label = layers.concat([gt_label, gt_box], axis=1)
+
+        # calculate mean average precision (mAP) of current mini-batch
+        map = layers.detection_map(
+            input,
+            label,
+            class_num,
+            background_label,
+            overlap_threshold=overlap_threshold,
+            evaluate_difficult=evaluate_difficult,
+            ap_version=ap_version)
+
+        # persistable states that carry detection statistics across mini-batches
+        states = []
+        states.append(
+            self._create_state(
+                dtype='int32', shape=None, suffix='accum_pos_count'))
+        states.append(
+            self._create_state(
+                dtype='float32', shape=None, suffix='accum_true_pos'))
+        states.append(
+            self._create_state(
+                dtype='float32', shape=None, suffix='accum_false_pos'))
+        var = self._create_state(dtype='int32', shape=[1], suffix='has_state')
+        self.helper.set_variable_initializer(
+            var, initializer=Constant(value=int(0)))
+        self.has_state = var
+
+        # calculate accumulative mAP
+        accum_map = layers.detection_map(
+            input,
+            label,
+            class_num,
+            background_label,
+            overlap_threshold=overlap_threshold,
+            evaluate_difficult=evaluate_difficult,
+            has_state=self.has_state,
+            input_states=states,
+            out_states=states,
+            ap_version=ap_version)
+
+        # mark that the accumulated states are valid from now on
+        layers.fill_constant(
+            shape=self.has_state.shape,
+            value=1,
+            dtype=self.has_state.dtype,
+            out=self.has_state)
+
+        self.cur_map = map
+        self.accum_map = accum_map
+
+    def _create_state(self, suffix, dtype, shape):
+        """
+        Create a state variable.
+
+        Args:
+            suffix(str): the state suffix.
+            dtype(str|core.VarDesc.VarType): the state data type.
+            shape(tuple|list): the shape of the state.
+
+        Returns: State variable
+        """
+        from . import unique_name
+        state = self.helper.create_variable(
+            name="_".join([unique_name.generate(self.helper.name), suffix]),
+            persistable=True,
+            dtype=dtype,
+            shape=shape)
+        return state
+
+    def get_map_var(self):
+        """
+        Returns: the mAP variable of the current mini-batch and the
+            accumulative mAP variable across mini-batches.
+        """
+        return self.cur_map, self.accum_map
+
+    def reset(self, executor, reset_program=None):
+        """
+        Reset metric states at the beginning of each pass or of a
+        user-specified batch.
+
+        Args:
+            executor(Executor|ParallelExecutor): an executor for executing
+                the reset_program.
+            reset_program(Program|None): a single Program for the reset process.
+                If None, a Program will be created.
+        """
+        from .framework import Program, Variable, program_guard
+        from . import layers
+
+        def _clone_var_(block, var):
+            assert isinstance(var, Variable)
+            return block.create_var(
+                name=var.name,
+                shape=var.shape,
+                dtype=var.dtype,
+                type=var.type,
+                lod_level=var.lod_level,
+                persistable=var.persistable)
+
+        if reset_program is None:
+            reset_program = Program()
+        with program_guard(main_program=reset_program):
+            var = _clone_var_(reset_program.current_block(), self.has_state)
+            # zero has_state so the next pass rebuilds the accumulated states
+            layers.fill_constant(
+                shape=var.shape, value=0, dtype=var.dtype, out=var)
+        executor.run(reset_program)
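
The ap_version argument documented above selects between the 11-point interpolated AP and the natural integral of the precision-recall curve. As a standalone illustration only (this is not the detection_map operator's implementation; the function name average_precision and the toy inputs are made up for the sketch), the two schemes for a single class can be computed with plain NumPy like this:

import numpy as np

def average_precision(scores, is_tp, num_gt, version='integral'):
    """Toy single-class AP computation, for illustration only.

    scores: detection confidences; is_tp: 1 for a true positive, 0 for a
    false positive; num_gt: number of ground-truth boxes of this class.
    """
    order = np.argsort(-np.asarray(scores))
    tp = np.cumsum(np.asarray(is_tp)[order])
    fp = np.cumsum(1 - np.asarray(is_tp)[order])
    recall = tp / float(num_gt)
    precision = tp / np.maximum(tp + fp, 1e-12)

    if version == '11point':
        # average the best precision reachable at recall >= 0.0, 0.1, ..., 1.0
        ap = 0.0
        for t in np.arange(0.0, 1.1, 0.1):
            p = precision[recall >= t].max() if np.any(recall >= t) else 0.0
            ap += p / 11.0
        return float(ap)
    # 'integral': sum precision weighted by the recall gained at each detection
    prev_recall = np.concatenate(([0.0], recall[:-1]))
    return float(np.sum(precision * (recall - prev_recall)))

# Example: 4 detections for a class that has 3 ground-truth boxes.
scores = [0.9, 0.8, 0.7, 0.6]
is_tp = [1, 0, 1, 1]
print(average_precision(scores, is_tp, num_gt=3, version='11point'))
print(average_precision(scores, is_tp, num_gt=3, version='integral'))

The 11-point variant averages the maximum precision at eleven fixed recall levels, while the integral variant accumulates precision over the actual recall increments of the ranked detections, so the two versions generally give slightly different numbers.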
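The accum_pos_count, accum_true_pos, and accum_false_pos states, together with has_state, are what let the second detection_map call report an accumulative mAP: each mini-batch's statistics are merged into persistable variables, and reset() zeroes has_state so the next pass starts a fresh accumulation. As a rough mental model only (hypothetical class and field names, no Paddle API involved), the bookkeeping amounts to:

class RunningDetectionStats(object):
    """Illustrative accumulator mirroring the accum_* state idea."""

    def __init__(self):
        self.reset()

    def reset(self):
        # analogous to zeroing has_state: forget everything seen so far
        self.accum_pos_count = {}   # class -> number of ground-truth boxes
        self.accum_true_pos = {}    # class -> list of (score, 1) entries
        self.accum_false_pos = {}   # class -> list of (score, 1) entries

    def update(self, batch_pos_count, batch_true_pos, batch_false_pos):
        # merge one mini-batch worth of statistics into the running state
        for cls, cnt in batch_pos_count.items():
            self.accum_pos_count[cls] = self.accum_pos_count.get(cls, 0) + cnt
        for cls, entries in batch_true_pos.items():
            self.accum_true_pos.setdefault(cls, []).extend(entries)
        for cls, entries in batch_false_pos.items():
            self.accum_false_pos.setdefault(cls, []).extend(entries)

# One pass: accumulate per batch, then reset before the next pass,
# just as map_evaluator.reset(exe) does in the docstring example.
stats = RunningDetectionStats()
stats.update({0: 2}, {0: [(0.9, 1)]}, {0: [(0.3, 1)]})
stats.update({0: 1}, {0: [(0.8, 1)]}, {0: []})
print(stats.accum_pos_count)  # {0: 3}
stats.reset()

Resetting between passes is what makes accum_map reflect only the current pass rather than the whole training history.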