@@ -806,6 +806,8 @@ def classification_report_imbalanced(
806
806
sample_weight = None ,
807
807
digits = 2 ,
808
808
alpha = 0.1 ,
809
+ output_dict = False ,
810
+ zero_division = "warn" ,
809
811
):
810
812
"""Build a classification report based on metrics used with imbalanced
811
813
dataset
@@ -816,38 +818,59 @@ def classification_report_imbalanced(
816
818
mean, and index balanced accuracy of the
817
819
geometric mean.
818
820
821
+ Read more in the :ref:`User Guide <classification_report>`.
822
+
819
823
Parameters
820
824
----------
821
- y_true : ndarray, shape (n_samples, )
825
+ y_true : 1d array-like, or label indicator array / sparse matrix
822
826
Ground truth (correct) target values.
823
827
824
- y_pred : ndarray, shape (n_samples, )
828
+ y_pred : 1d array-like, or label indicator array / sparse matrix
825
829
Estimated targets as returned by a classifier.
826
830
827
- labels : list, optional
828
- The set of labels to include when ``average != 'binary'``, and their
829
- order if ``average is None``. Labels present in the data can be
830
- excluded, for example to calculate a multiclass average ignoring a
831
- majority negative class, while labels not present in the data will
832
- result in 0 components in a macro average.
831
+ labels : array-like of shape (n_labels,), default=None
832
+ Optional list of label indices to include in the report.
833
833
834
- target_names : list of strings, optional
834
+ target_names : list of str of shape (n_labels,), default=None
835
835
Optional display names matching the labels (same order).
836
836
837
- sample_weight : ndarray, shape (n_samples, )
837
+ sample_weight : array-like of shape (n_samples,), default=None
838
838
Sample weights.
839
839
840
- digits : int, optional (default=2)
841
- Number of digits for formatting output floating point values
840
+ digits : int, default=2
841
+ Number of digits for formatting output floating point values.
842
+ When ``output_dict`` is ``True``, this will be ignored and the
843
+ returned values will not be rounded.
842
844
843
- alpha : float, optional ( default=0.1)
845
+ alpha : float, default=0.1
844
846
Weighting factor.
845
847
848
+ output_dict : bool, default=False
849
+ If True, return output as dict.
850
+
851
+ .. versionadded:: 0.7
852
+
853
+ zero_division : "warn" or {0, 1}, default="warn"
854
+ Sets the value to return when there is a zero division. If set to
855
+ "warn", this acts as 0, but warnings are also raised.
856
+
857
+ .. versionadded:: 0.7
858
+
846
859
Returns
847
860
-------
848
- report : string
861
+ report : string / dict
849
862
Text summary of the precision, recall, specificity, geometric mean,
850
863
and index balanced accuracy.
864
+ Dictionary returned if output_dict is True. Dictionary has the
865
+ following structure::
866
+
867
+ {'label 1': {'pre':0.5,
868
+ 'rec':1.0,
869
+ ...
870
+ },
871
+ 'label 2': { ... },
872
+ ...
873
+ }
851
874
852
875
Examples
853
876
--------
@@ -883,7 +906,7 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
883
906
last_line_heading = "avg / total"
884
907
885
908
if target_names is None :
886
- target_names = ["%s" % l for l in labels ]
909
+ target_names = [f" { label } " for label in labels ]
887
910
name_width = max (len (cn ) for cn in target_names )
888
911
width = max (name_width , len (last_line_heading ), digits )
889
912
@@ -905,6 +928,7 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
905
928
labels = labels ,
906
929
average = None ,
907
930
sample_weight = sample_weight ,
931
+ zero_division = zero_division
908
932
)
909
933
# Specificity
910
934
specificity = specificity_score (
@@ -934,33 +958,50 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
934
958
sample_weight = sample_weight ,
935
959
)
936
960
961
+ report_dict = {}
937
962
for i , label in enumerate (labels ):
963
+ report_dict_label = {}
938
964
values = [target_names [i ]]
939
- for v in (
940
- precision [i ],
941
- recall [i ],
942
- specificity [i ],
943
- f1 [i ],
944
- geo_mean [i ],
945
- iba [i ],
965
+ for score_name , score_value in zip (
966
+ headers [1 :- 1 ],
967
+ [
968
+ precision [i ],
969
+ recall [i ],
970
+ specificity [i ],
971
+ f1 [i ],
972
+ geo_mean [i ],
973
+ iba [i ],
974
+ ]
946
975
):
947
- values += ["{0:0.{1}f}" .format (v , digits )]
948
- values += ["{}" .format (support [i ])]
976
+ values += ["{0:0.{1}f}" .format (score_value , digits )]
977
+ report_dict_label [score_name ] = score_value
978
+ values += [f"{ support [i ]} " ]
979
+ report_dict_label [headers [- 1 ]] = support [i ]
949
980
report += fmt % tuple (values )
950
981
982
+ report_dict [label ] = report_dict_label
983
+
951
984
report += "\n "
952
985
953
986
# compute averages
954
987
values = [last_line_heading ]
955
- for v in (
956
- np .average (precision , weights = support ),
957
- np .average (recall , weights = support ),
958
- np .average (specificity , weights = support ),
959
- np .average (f1 , weights = support ),
960
- np .average (geo_mean , weights = support ),
961
- np .average (iba , weights = support ),
988
+ for score_name , score_value in zip (
989
+ headers [1 :- 1 ],
990
+ [
991
+ np .average (precision , weights = support ),
992
+ np .average (recall , weights = support ),
993
+ np .average (specificity , weights = support ),
994
+ np .average (f1 , weights = support ),
995
+ np .average (geo_mean , weights = support ),
996
+ np .average (iba , weights = support ),
997
+ ]
962
998
):
963
- values += ["{0:0.{1}f}" .format (v , digits )]
964
- values += ["{}" .format (np .sum (support ))]
999
+ values += ["{0:0.{1}f}" .format (score_value , digits )]
1000
+ report_dict [f"avg_{ score_name } " ] = score_value
1001
+ values += [f"{ np .sum (support )} " ]
965
1002
report += fmt % tuple (values )
1003
+ report_dict ["total_support" ] = np .sum (support )
1004
+
1005
+ if output_dict :
1006
+ return report_dict
966
1007
return report
0 commit comments