|
13 | 13 | # limitations under the License.
|
14 | 14 | """
|
15 | 15 | Fluid Metrics
|
16 |
| -
|
17 |
| -The metrics are accomplished via Python natively. |
18 | 16 | """
|
19 | 17 |
|
20 | 18 | from __future__ import print_function
|
|
24 | 22 | import warnings
|
25 | 23 | import six
|
26 | 24 |
|
| 25 | +from .layer_helper import LayerHelper |
| 26 | +from .initializer import Constant |
| 27 | +from . import unique_name |
| 28 | +from .framework import Program, Variable, program_guard |
| 29 | +from . import layers |
| 30 | + |
27 | 31 | __all__ = [
|
28 | 32 | 'MetricBase',
|
29 | 33 | 'CompositeMetric',
|
@@ -478,67 +482,6 @@ def eval(self):
|
478 | 482 | return avg_distance, avg_instance_error
|
479 | 483 |
|
480 | 484 |
|
481 |
| -class DetectionMAP(MetricBase): |
482 |
| - """ |
483 |
| - Calculate the detection mean average precision (mAP). |
484 |
| - mAP is the metric to measure the accuracy of object detectors |
485 |
| - like Faster R-CNN, SSD, etc. |
486 |
| - It is the average of the maximum precisions at different recall values. |
487 |
| - Please get more information from the following articles: |
488 |
| - https://sanchom.wordpress.com/tag/average-precision/ |
489 |
| -
|
490 |
| - https://arxiv.org/abs/1512.02325 |
491 |
| -
|
492 |
| - The general steps are as follows: |
493 |
| -
|
494 |
| - 1. calculate the true positive and false positive according to the input |
495 |
| - of detection and labels. |
496 |
| - 2. calculate mAP value, support two versions: '11 point' and 'integral'. |
497 |
| -
|
498 |
| - Examples: |
499 |
| - .. code-block:: python |
500 |
| -
|
501 |
| - pred = fluid.layers.fc(input=data, size=1000, act="tanh") |
502 |
| - batch_map = layers.detection_map( |
503 |
| - input, |
504 |
| - label, |
505 |
| - class_num, |
506 |
| - background_label, |
507 |
| - overlap_threshold=overlap_threshold, |
508 |
| - evaluate_difficult=evaluate_difficult, |
509 |
| - ap_version=ap_version) |
510 |
| - metric = fluid.metrics.DetectionMAP() |
511 |
| - for data in train_reader(): |
512 |
| - loss, preds, labels = exe.run(fetch_list=[cost, batch_map]) |
513 |
| - batch_size = data[0] |
514 |
| - metric.update(value=batch_map, weight=batch_size) |
515 |
| - numpy_map = metric.eval() |
516 |
| - """ |
517 |
| - |
518 |
| - def __init__(self, name=None): |
519 |
| - super(DetectionMAP, self).__init__(name) |
520 |
| - # the current map value |
521 |
| - self.value = .0 |
522 |
| - self.weight = .0 |
523 |
| - |
524 |
| - def update(self, value, weight): |
525 |
| - if not _is_number_or_matrix_(value): |
526 |
| - raise ValueError( |
527 |
| - "The 'value' must be a number(int, float) or a numpy ndarray.") |
528 |
| - if not _is_number_(weight): |
529 |
| - raise ValueError("The 'weight' must be a number(int, float).") |
530 |
| - self.value += value |
531 |
| - self.weight += weight |
532 |
| - |
533 |
| - def eval(self): |
534 |
| - if self.weight == 0: |
535 |
| - raise ValueError( |
536 |
| - "There is no data in DetectionMAP Metrics. " |
537 |
| - "Please check layers.detection_map output has added to DetectionMAP." |
538 |
| - ) |
539 |
| - return self.value / self.weight |
540 |
| - |
541 |
| - |
542 | 485 | class Auc(MetricBase):
|
543 | 486 | """
|
544 | 487 | Auc metric adapts to the binary classification.
|
@@ -616,3 +559,179 @@ def eval(self):
|
616 | 559 | idx -= 1
|
617 | 560 |
|
618 | 561 | return auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0
|
| 562 | + |
| 563 | + |
| 564 | +class DetectionMAP(object): |
| 565 | + """ |
| 566 | + Calculate the detection mean average precision (mAP). |
| 567 | +
|
| 568 | + The general steps are as follows: |
| 569 | + 1. calculate the true positive and false positive according to the input |
| 570 | + of detection and labels. |
| 571 | + 2. calculate mAP value, support two versions: '11 point' and 'integral'. |
| 572 | +
|
| 573 | + Please get more information from the following articles: |
| 574 | + https://sanchom.wordpress.com/tag/average-precision/ |
| 575 | + https://arxiv.org/abs/1512.02325 |
| 576 | +
|
| 577 | + Args: |
| 578 | + input (Variable): The detection results, which is a LoDTensor with shape |
| 579 | + [M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax]. |
| 580 | + gt_label (Variable): The ground truth label index, which is a LoDTensor |
| 581 | + with shape [N, 1]. |
| 582 | + gt_box (Variable): The ground truth bounding box (bbox), which is a |
| 583 | + LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax]. |
| 584 | + gt_difficult (Variable|None): Whether this ground truth is a difficult |
| 585 | + bounding bbox, which can be a LoDTensor [N, 1] or not set. If None, |
| 586 | + it means all the ground truth labels are not difficult bbox. |
| 587 | + class_num (int): The class number. |
| 588 | + background_label (int): The index of background label, the background |
| 589 | + label will be ignored. If set to -1, then all categories will be |
| 590 | + considered, 0 by defalut. |
| 591 | + overlap_threshold (float): The threshold for deciding true/false |
| 592 | + positive, 0.5 by defalut. |
| 593 | + evaluate_difficult (bool): Whether to consider difficult ground truth |
| 594 | + for evaluation, True by defalut. This argument does not work when |
| 595 | + gt_difficult is None. |
| 596 | + ap_version (string): The average precision calculation ways, it must be |
| 597 | + 'integral' or '11point'. Please check |
| 598 | + https://sanchom.wordpress.com/tag/average-precision/ for details. |
| 599 | + - 11point: the 11-point interpolated average precision. |
| 600 | + - integral: the natural integral of the precision-recall curve. |
| 601 | +
|
| 602 | + Examples: |
| 603 | + .. code-block:: python |
| 604 | +
|
| 605 | + exe = fluid.Executor(place) |
| 606 | + map_evaluator = fluid.Evaluator.DetectionMAP(input, |
| 607 | + gt_label, gt_box, gt_difficult) |
| 608 | + cur_map, accum_map = map_evaluator.get_map_var() |
| 609 | + fetch = [cost, cur_map, accum_map] |
| 610 | + for epoch in PASS_NUM: |
| 611 | + map_evaluator.reset(exe) |
| 612 | + for data in batches: |
| 613 | + loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch) |
| 614 | +
|
| 615 | + In the above example: |
| 616 | +
|
| 617 | + 'cur_map_v' is the mAP of current mini-batch. |
| 618 | + 'accum_map_v' is the accumulative mAP of one pass. |
| 619 | + """ |
| 620 | + |
| 621 | + def __init__(self, |
| 622 | + input, |
| 623 | + gt_label, |
| 624 | + gt_box, |
| 625 | + gt_difficult=None, |
| 626 | + class_num=None, |
| 627 | + background_label=0, |
| 628 | + overlap_threshold=0.5, |
| 629 | + evaluate_difficult=True, |
| 630 | + ap_version='integral'): |
| 631 | + |
| 632 | + self.helper = LayerHelper('map_eval') |
| 633 | + gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype) |
| 634 | + if gt_difficult: |
| 635 | + gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype) |
| 636 | + label = layers.concat([gt_label, gt_difficult, gt_box], axis=1) |
| 637 | + else: |
| 638 | + label = layers.concat([gt_label, gt_box], axis=1) |
| 639 | + |
| 640 | + # calculate mean average precision (mAP) of current mini-batch |
| 641 | + map = layers.detection_map( |
| 642 | + input, |
| 643 | + label, |
| 644 | + class_num, |
| 645 | + background_label, |
| 646 | + overlap_threshold=overlap_threshold, |
| 647 | + evaluate_difficult=evaluate_difficult, |
| 648 | + ap_version=ap_version) |
| 649 | + |
| 650 | + states = [] |
| 651 | + states.append( |
| 652 | + self._create_state( |
| 653 | + dtype='int32', shape=None, suffix='accum_pos_count')) |
| 654 | + states.append( |
| 655 | + self._create_state( |
| 656 | + dtype='float32', shape=None, suffix='accum_true_pos')) |
| 657 | + states.append( |
| 658 | + self._create_state( |
| 659 | + dtype='float32', shape=None, suffix='accum_false_pos')) |
| 660 | + var = self._create_state(dtype='int32', shape=[1], suffix='has_state') |
| 661 | + self.helper.set_variable_initializer( |
| 662 | + var, initializer=Constant(value=int(0))) |
| 663 | + self.has_state = var |
| 664 | + |
| 665 | + # calculate accumulative mAP |
| 666 | + accum_map = layers.detection_map( |
| 667 | + input, |
| 668 | + label, |
| 669 | + class_num, |
| 670 | + background_label, |
| 671 | + overlap_threshold=overlap_threshold, |
| 672 | + evaluate_difficult=evaluate_difficult, |
| 673 | + has_state=self.has_state, |
| 674 | + input_states=states, |
| 675 | + out_states=states, |
| 676 | + ap_version=ap_version) |
| 677 | + |
| 678 | + layers.fill_constant( |
| 679 | + shape=self.has_state.shape, |
| 680 | + value=1, |
| 681 | + dtype=self.has_state.dtype, |
| 682 | + out=self.has_state) |
| 683 | + |
| 684 | + self.cur_map = map |
| 685 | + self.accum_map = accum_map |
| 686 | + |
| 687 | + def _create_state(self, suffix, dtype, shape): |
| 688 | + """ |
| 689 | + Create state variable. |
| 690 | + Args: |
| 691 | + suffix(str): the state suffix. |
| 692 | + dtype(str|core.VarDesc.VarType): the state data type |
| 693 | + shape(tuple|list): the shape of state |
| 694 | + Returns: State variable |
| 695 | + """ |
| 696 | + state = self.helper.create_variable( |
| 697 | + name="_".join([unique_name.generate(self.helper.name), suffix]), |
| 698 | + persistable=True, |
| 699 | + dtype=dtype, |
| 700 | + shape=shape) |
| 701 | + return state |
| 702 | + |
| 703 | + def get_map_var(self): |
| 704 | + """ |
| 705 | + Returns: mAP variable of current mini-batch and |
| 706 | + accumulative mAP variable cross mini-batches. |
| 707 | + """ |
| 708 | + return self.cur_map, self.accum_map |
| 709 | + |
| 710 | + def reset(self, executor, reset_program=None): |
| 711 | + """ |
| 712 | + Reset metric states at the begin of each pass/user specified batch. |
| 713 | +
|
| 714 | + Args: |
| 715 | + executor(Executor): a executor for executing |
| 716 | + the reset_program. |
| 717 | + reset_program(Program|None): a single Program for reset process. |
| 718 | + If None, will create a Program. |
| 719 | + """ |
| 720 | + |
| 721 | + def _clone_var_(block, var): |
| 722 | + assert isinstance(var, Variable) |
| 723 | + return block.create_var( |
| 724 | + name=var.name, |
| 725 | + shape=var.shape, |
| 726 | + dtype=var.dtype, |
| 727 | + type=var.type, |
| 728 | + lod_level=var.lod_level, |
| 729 | + persistable=var.persistable) |
| 730 | + |
| 731 | + if reset_program is None: |
| 732 | + reset_program = Program() |
| 733 | + with program_guard(main_program=reset_program): |
| 734 | + var = _clone_var_(reset_program.current_block(), self.has_state) |
| 735 | + layers.fill_constant( |
| 736 | + shape=var.shape, value=0, dtype=var.dtype, out=var) |
| 737 | + executor.run(reset_program) |
0 commit comments