@@ -29,9 +29,22 @@ def __call__(self, y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> floa
 
 
 def prediction_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate prediction accuracy. Supports both multiclass and multilabel.
 
+    The prediction accuracy is calculated as:
+
+    .. math::
+
+        \text{Accuracy} = \frac{\sum_{i=1}^N \mathbb{1}(y_{\text{true},i} = y_{\text{pred},i})}{N}
+
+    where:
+    - :math:`N` is the total number of samples,
+    - :math:`y_{\text{true},i}` is the true label for the :math:`i`-th sample,
+    - :math:`y_{\text{pred},i}` is the predicted label for the :math:`i`-th sample,
+    - :math:`\mathbb{1}(\text{condition})` is the indicator function that equals 1 if the condition
+      is true and 0 otherwise.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction accuracy
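As a quick check of the formula above, accuracy is just the mean of the element-wise indicator. A minimal sketch on hypothetical labels (not the library's input-handling pipeline):

```python
import numpy as np

# Hypothetical labels: 4 of the 5 predictions match the true labels.
y_true = np.array([0, 1, 2, 2, 1])
y_pred = np.array([0, 2, 2, 2, 1])

# Mean of the indicator 1(y_true_i == y_pred_i), as in the formula above.
accuracy = np.mean(y_true == y_pred)
print(accuracy)  # 0.8
```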
@@ -41,9 +54,22 @@ def prediction_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) ->
 
 
 def _prediction_roc_auc_multiclass(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate roc_auc for multiclass.
 
+    The ROC AUC score for multiclass is calculated as the mean ROC AUC score
+    across all classes, where each class is treated as a binary classification task
+    (one-vs-rest).
+
+    .. math::
+
+        \text{ROC AUC}_{\text{multiclass}} = \frac{1}{K} \sum_{k=1}^K \text{ROC AUC}_k
+
+    where:
+    - :math:`K` is the number of classes,
+    - :math:`\text{ROC AUC}_k` is the ROC AUC score for the :math:`k`-th class,
+      calculated by treating it as a binary classification problem (class :math:`k` vs rest).
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction roc_auc
@@ -61,9 +87,13 @@ def _prediction_roc_auc_multiclass(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VAL
 
 
 def _prediction_roc_auc_multilabel(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate roc_auc for multilabel.
 
+    This function internally uses :func:`sklearn.metrics.roc_auc_score` with ``average="macro"``. Refer to the
+    `scikit-learn documentation <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html>`__
+    for more details.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction accuracy
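For reference, the macro-averaged multilabel call looks like this on hypothetical indicator and score arrays (a sketch, not the function's exact body):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical multilabel data: binary indicators and one score per label.
y_true = np.array([[1, 0, 1],
                   [0, 1, 1],
                   [1, 1, 0],
                   [0, 0, 1]])
y_score = np.array([[0.9, 0.2, 0.8],
                    [0.1, 0.7, 0.6],
                    [0.8, 0.6, 0.3],
                    [0.2, 0.1, 0.9]])

# average="macro": ROC AUC per label column, then the unweighted mean.
print(roc_auc_score(y_true, y_score, average="macro"))
```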
@@ -72,12 +102,16 @@ def _prediction_roc_auc_multilabel(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VAL
 
 
 def prediction_roc_auc(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
-    Calculate roc_auc for multiclass and multilabel.
+    r"""
+    Calculate ROC AUC for multiclass and multilabel classification.
+
+    The ROC AUC measures the ability of a model to distinguish between classes.
+    It is calculated as the area under the curve of the true positive rate (TPR)
+    against the false positive rate (FPR) at various threshold settings.
 
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
-    :return: Score of the prediction roc_auc
+    :return: Score of the prediction ROC AUC
     """
     y_true_, y_pred_ = transform(y_true, y_pred)
     if y_pred_.ndim == y_true_.ndim == 1:
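The TPR/FPR definition above can be verified with scikit-learn's curve utilities; a small sketch on a hypothetical binary problem, independent of the multiclass/multilabel dispatch:

```python
import numpy as np
from sklearn.metrics import auc, roc_auc_score, roc_curve

# Hypothetical binary problem.
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# Build the ROC curve (TPR against FPR over thresholds) and integrate it.
fpr, tpr, _ = roc_curve(y_true, y_score)
print(auc(fpr, tpr))                   # 0.75
print(roc_auc_score(y_true, y_score))  # same value, computed directly
```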
@@ -90,9 +124,13 @@ def prediction_roc_auc(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) ->
 
 
 def prediction_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate prediction precision. Supports both multiclass and multilabel.
 
+    This function internally uses :func:`sklearn.metrics.precision_score` with ``average="macro"``. Refer to the
+    `scikit-learn documentation <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html>`__
+    for more details.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction precision
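A short sketch of the macro-averaged call on hypothetical multiclass labels:

```python
from sklearn.metrics import precision_score

# Per-class precision is computed first, then averaged without weighting
# by class frequency.
y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 2, 1, 0, 0, 1]
print(precision_score(y_true, y_pred, average="macro"))  # ~0.22
```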
@@ -101,9 +139,13 @@ def prediction_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -
 
 
 def prediction_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate prediction recall. Supports both multiclass and multilabel.
 
+    This function internally uses :func:`sklearn.metrics.recall_score` with ``average="macro"``. Refer to the
+    `scikit-learn documentation <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html>`__
+    for more details.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction recall
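Macro averaging works the same way for multilabel indicator arrays; a hedged sketch:

```python
import numpy as np
from sklearn.metrics import recall_score

# Hypothetical multilabel indicators: recall is computed per label column,
# then averaged: (1.0 + 1.0 + 0.5) / 3.
y_true = np.array([[1, 0, 1],
                   [0, 1, 1]])
y_pred = np.array([[1, 0, 0],
                   [0, 1, 1]])
print(recall_score(y_true, y_pred, average="macro"))  # 0.8333...
```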
@@ -112,9 +154,13 @@ def prediction_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> f
 
 
 def prediction_f1(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
-    """
+    r"""
     Calculate prediction f1 score. Supports both multiclass and multilabel.
 
+    This function internally uses :func:`sklearn.metrics.f1_score` with ``average="macro"``. Refer to the
+    `scikit-learn documentation <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html>`__
+    for more details.
+
     :param y_true: True values of labels
     :param y_pred: Predicted values of labels
     :return: Score of the prediction accuracy
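Macro F1 averages the per-class F1 scores, each the harmonic mean of that class's precision and recall. A sketch reusing the hypothetical labels from the precision example:

```python
from sklearn.metrics import f1_score

# Class 0: precision 2/3, recall 1.0 -> F1 0.8; classes 1 and 2: F1 0.0.
y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 2, 1, 0, 0, 1]
print(f1_score(y_true, y_pred, average="macro"))  # 0.8 / 3 ≈ 0.27
```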