
Commit ab63cd0

Refactor/update docstrings for metrics (#52)
1 parent d7d3e67 commit ab63cd0

6 files changed: +457 −94 lines

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ instance/

 # Sphinx documentation
 docs/build/
-docs/source/apiref
+docs/source/autoapi
 docs/source/tutorials

 # PyBuilder

autointent/metrics/prediction.py

Lines changed: 55 additions & 9 deletions
@@ -29,9 +29,22 @@ def __call__(self, y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> floa


 def prediction_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate prediction accuracy. Supports both multiclass and multilabel.

+    The prediction accuracy is calculated as:
+
+    .. math::
+
+        \text{Accuracy} = \frac{\sum_{i=1}^N \mathbb{1}(y_{\text{true},i} = y_{\text{pred},i})}{N}
+
+    where:
+    - :math:`N` is the total number of samples,
+    - :math:`y_{\text{true},i}` is the true label for the :math:`i`-th sample,
+    - :math:`y_{\text{pred},i}` is the predicted label for the :math:`i`-th sample,
+    - :math:`\mathbb{1}(\text{condition})` is the indicator function that equals 1 if the condition
+      is true and 0 otherwise.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction accuracy
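
For readers skimming the diff: the added formula is a plain mean of exact matches. A minimal sketch with NumPy (data and variable names invented for illustration; the real function also accepts multilabel inputs via its transform step):

import numpy as np

# Invented multiclass labels, purely to illustrate the docstring's formula.
y_true = np.array([0, 1, 2, 2])
y_pred = np.array([0, 1, 1, 2])

# Accuracy = sum of indicator(y_true_i == y_pred_i) / N
accuracy = np.mean(y_true == y_pred)
print(accuracy)  # 0.75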
@@ -41,9 +54,22 @@ def prediction_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) ->


 def _prediction_roc_auc_multiclass(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate roc_auc for multiclass.

+    The ROC AUC score for multiclass is calculated as the mean ROC AUC score
+    across all classes, where each class is treated as a binary classification task
+    (one-vs-rest).
+
+    .. math::
+
+        \text{ROC AUC}_{\text{multiclass}} = \frac{1}{K} \sum_{k=1}^K \text{ROC AUC}_k
+
+    where:
+    - :math:`K` is the number of classes,
+    - :math:`\text{ROC AUC}_k` is the ROC AUC score for the :math:`k`-th class,
+      calculated by treating it as a binary classification problem (class :math:`k` vs rest).
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction roc_auc
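
The one-vs-rest averaging described above can be reproduced directly with scikit-learn. A hedged sketch with invented scores (not the library's internals; the actual helper may derive per-class scores differently):

import numpy as np
from sklearn.metrics import roc_auc_score

# Invented data: 4 samples, 3 classes, per-class probability scores.
y_true = np.array([0, 1, 2, 1])
y_scores = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.2, 0.6],
    [0.3, 0.5, 0.2],
])

# Mean of per-class binary (class k vs rest) ROC AUC scores, as in the formula.
per_class = [
    roc_auc_score((y_true == k).astype(int), y_scores[:, k])
    for k in range(y_scores.shape[1])
]
print(np.mean(per_class))  # 1.0 for this toy data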
@@ -61,9 +87,13 @@ def _prediction_roc_auc_multiclass(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VAL


 def _prediction_roc_auc_multilabel(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate roc_auc for multilabel.

+    This function internally uses :func:`sklearn.metrics.roc_auc_score` with `average=macro`. Refer to the
+    `scikit-learn documentation <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html>`__
+    for more details.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction accuracy
@@ -72,12 +102,16 @@ def _prediction_roc_auc_multilabel(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VAL


 def prediction_roc_auc(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
-    Calculate roc_auc for multiclass and multilabel.
+    r"""
+    Calculate ROC AUC for multiclass and multilabel classification.
+
+    The ROC AUC measures the ability of a model to distinguish between classes.
+    It is calculated as the area under the curve of the true positive rate (TPR)
+    against the false positive rate (FPR) at various threshold settings.

     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
-    :return: Score of the prediction roc_auc
+    :return: Score of the prediction ROC AUC
     """
     y_true_, y_pred_ = transform(y_true, y_pred)
     if y_pred_.ndim == y_true_.ndim == 1:
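
The context lines at the end of this hunk show the dispatch: 1-D label arrays take the multiclass path, 2-D indicator matrices the multilabel one (which, per the docstring above, relies on macro-averaged sklearn.metrics.roc_auc_score). A sketch of that shape check, with stand-in helper callables rather than the module's private functions, and an error branch that is assumed, not shown in the diff:

import numpy as np

def roc_auc_dispatch(y_true, y_pred, multiclass_fn, multilabel_fn):
    # 1-D label vectors -> multiclass helper; 2-D matrices -> multilabel helper.
    y_true_, y_pred_ = np.asarray(y_true), np.asarray(y_pred)
    if y_pred_.ndim == y_true_.ndim == 1:
        return multiclass_fn(y_true_, y_pred_)
    if y_pred_.ndim == y_true_.ndim == 2:
        return multilabel_fn(y_true_, y_pred_)
    msg = "y_true and y_pred must both be 1-D (multiclass) or 2-D (multilabel)"
    raise ValueError(msg)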
@@ -90,9 +124,13 @@ def prediction_roc_auc(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) ->


 def prediction_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate prediction precision. Supports both multiclass and multilabel.

+    This function internally uses :func:`sklearn.metrics.precision_score` with `average=macro`. Refer to the
+    `scikit-learn documentation <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html>`__
+    for more details.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction precision
@@ -101,9 +139,13 @@ def prediction_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -


 def prediction_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate prediction recall. Supports both multiclass and multilabel.

+    This function internally uses :func:`sklearn.metrics.recall_score` with `average=macro`. Refer to the
+    `scikit-learn documentation <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html>`__
+    for more details.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction recall
@@ -112,9 +154,13 @@ def prediction_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> f


 def prediction_f1(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate prediction f1 score. Supports both multiclass and multilabel.

+    This function internally uses :func:`sklearn.metrics.f1_score` with `average=macro`. Refer to the
+    `scikit-learn documentation <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html>`__
+    for more details.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction accuracy
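
The precision, recall, and f1 docstrings above all defer to the same scikit-learn calls. This is what macro averaging computes on invented labels (per-class metric first, then an unweighted mean over classes):

from sklearn.metrics import f1_score, precision_score, recall_score

# Invented multiclass labels, for illustration only.
y_true = [0, 1, 2, 2, 1]
y_pred = [0, 2, 2, 2, 1]

print(precision_score(y_true, y_pred, average="macro"))
print(recall_score(y_true, y_pred, average="macro"))
print(f1_score(y_true, y_pred, average="macro"))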

autointent/metrics/regexp.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,22 @@ def __call__(self, y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> floa


 def regexp_partial_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate regexp partial accuracy.

+    The regexp partial accuracy is calculated as:
+
+    .. math::
+
+        \text{Partial Accuracy} = \frac{\sum_{i=1}^N \mathbb{1}(y_{\text{true},i} \in y_{\text{pred},i})}{N}
+
+    where:
+    - :math:`N` is the total number of samples,
+    - :math:`y_{\text{true},i}` is the true label for the :math:`i`-th sample,
+    - :math:`y_{\text{pred},i}` is the predicted label for the :math:`i`-th sample,
+    - :math:`\mathbb{1}(\text{condition})` is the indicator function that equals 1 if the condition
+      is true and 0 otherwise.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the regexp metric
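
The membership test in the formula (:math:`y_{\text{true},i} \in y_{\text{pred},i}`) suggests each prediction is the collection of labels whose patterns matched. Under that reading, a minimal sketch with invented data:

# Invented regexp outputs: each prediction lists every label whose pattern matched.
y_true = [0, 1, 2, 3]
y_pred = [[0, 2], [1], [], [1, 2]]

# Fraction of samples whose true label appears among the matches.
partial_accuracy = sum(t in p for t, p in zip(y_true, y_pred)) / len(y_true)
print(partial_accuracy)  # 0.5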
@@ -39,9 +52,24 @@ def regexp_partial_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE


 def regexp_partial_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate regexp partial precision.

+    The regexp partial precision is calculated as:
+
+    .. math::
+
+        \text{Partial Precision} = \frac{\sum_{i=1}^N \mathbb{1}(y_{\text{true},i}
+        \in y_{\text{pred},i})}{\sum_{i=1}^N \mathbb{1}(|y_{\text{pred},i}| > 0)}
+
+    where:
+    - :math:`N` is the total number of samples,
+    - :math:`y_{\text{true},i}` is the true label for the :math:`i`-th sample,
+    - :math:`y_{\text{pred},i}` is the predicted label for the :math:`i`-th sample,
+    - :math:`|y_{\text{pred},i}|` is the number of predicted labels for the :math:`i`-th sample,
+    - :math:`\mathbb{1}(\text{condition})` is the indicator function that equals 1 if the condition
+      is true and 0 otherwise.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the regexp metric
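
Partial precision differs from the partial-accuracy sketch above only in the denominator, which counts samples where the regexps matched at least one label:

# Same invented data as the partial-accuracy sketch.
y_true = [0, 1, 2, 3]
y_pred = [[0, 2], [1], [], [1, 2]]

hits = sum(t in p for t, p in zip(y_true, y_pred))
matched = sum(len(p) > 0 for p in y_pred)
print(hits / matched)  # 2/3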
