Skip to content

Commit 316eb3e

Browse files
author
Yibing Liu
committed
Add doc for layers.auc
1 parent cafdeb0 commit 316eb3e

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

python/paddle/fluid/layers/metric.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,43 @@ def accuracy(input, label, k=1, correct=None, total=None):
5353

5454

5555
def auc(input, label, curve='ROC', num_thresholds=200):
56+
"""
57+
**Area Under The Curve (AUC) Layer**
58+
59+
This implementation computes the AUC according to forward output and label.
60+
It is used very widely in binary classification evaluation.
61+
62+
As a note: If input label contains values other than 0 and 1, it will be
63+
cast to bool. You can find the relevant definitions `here
64+
<https://en.wikipedia.org/wiki/Receiver_operating_characteristic
65+
#Area_under_the_curve>`_.
66+
67+
There are two types of possible curves:
68+
1. ROC: Receiver operating characteristic
69+
2. PR: Precision Recall
70+
71+
Args:
72+
input(Variable): A floating-point 2D Variable, values are in the range
73+
[0, 1]. Each row is sorted in descending order. This
74+
input should be the output of topk. Typically, this
75+
Variable indicates the probability of each label.
76+
label(Variable): A 2D int Variable indicating the label of the training
77+
data. The height is batch size and width is always 1.
78+
curve(str): Curve type, can be 'ROC' or 'PR'. Default 'ROC'.
79+
num_thresholds(int): The number of thresholds to use when discretizing
80+
the roc curve. Default 200.
81+
82+
Returns:
83+
Variable: A scalar representing the current AUC.
84+
85+
Examples:
86+
.. code-block:: python
87+
88+
# network is a binary classification model and label the ground truth
89+
prediction = network(image, is_infer=True)
90+
auc_out=fluid.layers.auc(input=prediction, label=label)
91+
"""
92+
5693
warnings.warn(
5794
"This interface not recommended, fluid.layers.auc compute the auc at every minibatch, \
5895
but can not aggregate them and get the pass AUC, because pass \

0 commit comments

Comments
 (0)