Commit 8fc6dbb

ENH add output_dict in classification_report_imbalanced (scikit-learn-contrib#770)
1 parent 599e24f commit 8fc6dbb

4 files changed: +121 −33 lines

doc/metrics.rst

Lines changed: 10 additions & 0 deletions

@@ -44,3 +44,13 @@ of the classes while keeping these accuracies balanced.
 The :func:`make_index_balanced_accuracy` :cite:`garcia2012effectiveness` can
 wrap any metric and give more importance to a specific class using the
 parameter ``alpha``.
+
+.. _classification_report:
+
+Summary of important metrics
+----------------------------
+
+The :func:`classification_report_imbalanced` will compute a set of metrics
+per class and summarize it in a table. The parameter `output_dict` allows
+the report to be returned as a Python dictionary rather than a string. This
+dictionary can be reused, for instance to build a pandas DataFrame.
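To make the documented workflow concrete, here is a minimal sketch (not part of the commit) that feeds the dictionary into pandas. The toy `y_true`/`y_pred` values are invented, and it assumes a version of imbalanced-learn that includes this change; the `isinstance` filter reflects that the `avg_*` and `total_support` entries are scalars rather than per-label dicts.

import pandas as pd

from imblearn.metrics import classification_report_imbalanced

# Invented toy targets; class 1 is the minority.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 1, 1, 1, 0]

report = classification_report_imbalanced(y_true, y_pred, output_dict=True)

# Keep only the per-label sub-dictionaries; the avg_* and total_support
# entries are scalar summaries and would not fit a per-label table.
per_label = {
    label: scores for label, scores in report.items() if isinstance(scores, dict)
}
print(pd.DataFrame(per_label).transpose())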

doc/whats_new/v0.7.rst

Lines changed: 5 additions & 0 deletions

@@ -71,6 +71,11 @@ Enhancements
 - Added Random Over-Sampling Examples (ROSE) class.
   :pr:`754` by :user:`Andrea Lorenzon <andrealorenzon>`.

+- Add option `output_dict` in
+  :func:`imblearn.metrics.classification_report_imbalanced` to return a
+  dictionary instead of a string.
+  :pr:`xx` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 Deprecation
 ...........

imblearn/metrics/_classification.py

Lines changed: 74 additions & 33 deletions

@@ -806,6 +806,8 @@ def classification_report_imbalanced(
     sample_weight=None,
     digits=2,
     alpha=0.1,
+    output_dict=False,
+    zero_division="warn",
 ):
     """Build a classification report based on metrics used with imbalanced
     dataset

@@ -816,38 +818,59 @@
     mean, and index balanced accuracy of the
     geometric mean.

+    Read more in the :ref:`User Guide <classification_report>`.
+
     Parameters
     ----------
-    y_true : ndarray, shape (n_samples, )
+    y_true : 1d array-like, or label indicator array / sparse matrix
         Ground truth (correct) target values.

-    y_pred : ndarray, shape (n_samples, )
+    y_pred : 1d array-like, or label indicator array / sparse matrix
         Estimated targets as returned by a classifier.

-    labels : list, optional
-        The set of labels to include when ``average != 'binary'``, and their
-        order if ``average is None``. Labels present in the data can be
-        excluded, for example to calculate a multiclass average ignoring a
-        majority negative class, while labels not present in the data will
-        result in 0 components in a macro average.
+    labels : array-like of shape (n_labels,), default=None
+        Optional list of label indices to include in the report.

-    target_names : list of strings, optional
+    target_names : list of str of shape (n_labels,), default=None
         Optional display names matching the labels (same order).

-    sample_weight : ndarray, shape (n_samples, )
+    sample_weight : array-like of shape (n_samples,), default=None
         Sample weights.

-    digits : int, optional (default=2)
-        Number of digits for formatting output floating point values
+    digits : int, default=2
+        Number of digits for formatting output floating point values.
+        When ``output_dict`` is ``True``, this will be ignored and the
+        returned values will not be rounded.

-    alpha : float, optional (default=0.1)
+    alpha : float, default=0.1
         Weighting factor.

+    output_dict : bool, default=False
+        If True, return output as dict.
+
+        .. versionadded:: 0.7
+
+    zero_division : "warn" or {0, 1}, default="warn"
+        Sets the value to return when there is a zero division. If set to
+        "warn", this acts as 0, but warnings are also raised.
+
+        .. versionadded:: 0.7
+
     Returns
     -------
-    report : string
+    report : string / dict
         Text summary of the precision, recall, specificity, geometric mean,
         and index balanced accuracy.
+        Dictionary returned if output_dict is True. Dictionary has the
+        following structure::
+
+            {'label 1': {'pre':0.5,
+                         'rec':1.0,
+                         ...
+                        },
+             'label 2': { ... },
+              ...
+            }

     Examples
     --------

@@ -883,7 +906,7 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
     last_line_heading = "avg / total"

     if target_names is None:
-        target_names = ["%s" % l for l in labels]
+        target_names = [f"{label}" for label in labels]
     name_width = max(len(cn) for cn in target_names)
     width = max(name_width, len(last_line_heading), digits)

@@ -905,6 +928,7 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
         labels=labels,
         average=None,
         sample_weight=sample_weight,
+        zero_division=zero_division
     )
     # Specificity
     specificity = specificity_score(

@@ -934,33 +958,50 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
         sample_weight=sample_weight,
     )

+    report_dict = {}
     for i, label in enumerate(labels):
+        report_dict_label = {}
         values = [target_names[i]]
-        for v in (
-            precision[i],
-            recall[i],
-            specificity[i],
-            f1[i],
-            geo_mean[i],
-            iba[i],
+        for score_name, score_value in zip(
+            headers[1:-1],
+            [
+                precision[i],
+                recall[i],
+                specificity[i],
+                f1[i],
+                geo_mean[i],
+                iba[i],
+            ]
         ):
-            values += ["{0:0.{1}f}".format(v, digits)]
-        values += ["{}".format(support[i])]
+            values += ["{0:0.{1}f}".format(score_value, digits)]
+            report_dict_label[score_name] = score_value
+        values += [f"{support[i]}"]
+        report_dict_label[headers[-1]] = support[i]
         report += fmt % tuple(values)

+        report_dict[label] = report_dict_label
+
     report += "\n"

     # compute averages
     values = [last_line_heading]
-    for v in (
-        np.average(precision, weights=support),
-        np.average(recall, weights=support),
-        np.average(specificity, weights=support),
-        np.average(f1, weights=support),
-        np.average(geo_mean, weights=support),
-        np.average(iba, weights=support),
+    for score_name, score_value in zip(
+        headers[1:-1],
+        [
+            np.average(precision, weights=support),
+            np.average(recall, weights=support),
+            np.average(specificity, weights=support),
+            np.average(f1, weights=support),
+            np.average(geo_mean, weights=support),
+            np.average(iba, weights=support),
+        ]
     ):
-        values += ["{0:0.{1}f}".format(v, digits)]
-    values += ["{}".format(np.sum(support))]
+        values += ["{0:0.{1}f}".format(score_value, digits)]
+        report_dict[f"avg_{score_name}"] = score_value
+    values += [f"{np.sum(support)}"]
     report += fmt % tuple(values)
+    report_dict["total_support"] = np.sum(support)
+
+    if output_dict:
+        return report_dict
     return report
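As a usage illustration (not shipped with the commit), here is a minimal sketch of consuming the new return value; the key names (`pre`, `rec`, `spe`, `f1`, `geo`, `iba`, `sup`, plus the `avg_*` and `total_support` entries) come from the code above, and the toy data is invented.

from imblearn.metrics import classification_report_imbalanced

y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 1, 1, 1, 0]

# Default behaviour is unchanged: a formatted text report.
print(classification_report_imbalanced(y_true, y_pred))

# With output_dict=True the metrics come back unrounded, keyed by label.
report = classification_report_imbalanced(y_true, y_pred, output_dict=True)
print(report[1]["rec"])          # recall of class 1
print(report["avg_pre"])         # support-weighted average precision
print(report["total_support"])   # total number of samples

Note that `zero_division` is simply forwarded to the underlying precision/recall computation, so the semantics documented above apply: `"warn"` acts as 0 but also raises a warning when a division by zero occurs.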

imblearn/metrics/tests/test_classification.py

Lines changed: 32 additions & 0 deletions

@@ -466,3 +466,35 @@ def test_iba_error_y_score_prob_error(score_loss):
     aps = make_index_balanced_accuracy(alpha=0.5, squared=True)(score_loss)
     with pytest.raises(AttributeError):
         aps(y_true, y_pred)
+
+
+def test_classification_report_imbalanced_dict():
+    iris = datasets.load_iris()
+    y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
+
+    report = classification_report_imbalanced(
+        y_true,
+        y_pred,
+        labels=np.arange(len(iris.target_names)),
+        target_names=iris.target_names,
+        output_dict=True,
+    )
+    outer_keys = set(report.keys())
+    inner_keys = set(report[0].keys())
+
+    expected_outer_keys = {
+        0,
+        1,
+        2,
+        "avg_pre",
+        "avg_rec",
+        "avg_spe",
+        "avg_f1",
+        "avg_geo",
+        "avg_iba",
+        "total_support",
+    }
+    expected_inner_keys = {'spe', 'f1', 'sup', 'rec', 'geo', 'iba', 'pre'}
+
+    assert outer_keys == expected_outer_keys
+    assert inner_keys == expected_inner_keys
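Assuming a development checkout of the repository (this invocation is not part of the commit), the new test can be run on its own via pytest's node-id selection:

import pytest

# Select just the new test by its node id.
pytest.main([
    "imblearn/metrics/tests/test_classification.py"
    "::test_classification_report_imbalanced_dict",
])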
