@@ -1072,7 +1072,7 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
10721072
10731073def plot_cumulative_gain (y_true , y_probas , title = 'Cumulative Gains Curve' ,
10741074 ax = None , figsize = None , title_fontsize = "large" ,
1075- text_fontsize = "medium" ):
1075+ text_fontsize = "medium" , class_names = None ):
10761076 """Generates the Cumulative Gains Plot from labels and scores/probabilities
10771077
10781078 The cumulative gains chart is used to determine the effectiveness of a
@@ -1104,6 +1104,10 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11041104 text_fontsize (string or int, optional): Matplotlib-style fontsizes.
11051105 Use e.g. "small", "medium", "large" or integer-values. Defaults to
11061106 "medium".
1107+
1108+ class_names (list of strings, optional): List of class names. Used for
1109+ the legend. Order should be synchronized with the order of classes
1110+ in y_probas.
11071111
11081112 Returns:
11091113 ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
@@ -1126,6 +1130,7 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11261130 y_probas = np .array (y_probas )
11271131
11281132 classes = np .unique (y_true )
1133+ if class_names is None : class_names = classes
11291134 if len (classes ) != 2 :
11301135 raise ValueError ('Cannot calculate Cumulative Gains for data with '
11311136 '{} category/ies' .format (len (classes )))
@@ -1141,8 +1146,8 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11411146
11421147 ax .set_title (title , fontsize = title_fontsize )
11431148
1144- ax .plot (percentages , gains1 , lw = 3 , label = 'Class {}' .format (classes [0 ]))
1145- ax .plot (percentages , gains2 , lw = 3 , label = 'Class {}' .format (classes [1 ]))
1149+ ax .plot (percentages , gains1 , lw = 3 , label = 'Class {}' .format (class_names [0 ]))
1150+ ax .plot (percentages , gains2 , lw = 3 , label = 'Class {}' .format (class_names [1 ]))
11461151
11471152 ax .set_xlim ([0.0 , 1.0 ])
11481153 ax .set_ylim ([0.0 , 1.0 ])
@@ -1160,7 +1165,7 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
11601165
11611166def plot_lift_curve (y_true , y_probas , title = 'Lift Curve' ,
11621167 ax = None , figsize = None , title_fontsize = "large" ,
1163- text_fontsize = "medium" ):
1168+ text_fontsize = "medium" , class_names = None ):
11641169 """Generates the Lift Curve from labels and scores/probabilities
11651170
11661171 The lift curve is used to determine the effectiveness of a
@@ -1192,6 +1197,10 @@ def plot_lift_curve(y_true, y_probas, title='Lift Curve',
11921197 text_fontsize (string or int, optional): Matplotlib-style fontsizes.
11931198 Use e.g. "small", "medium", "large" or integer-values. Defaults to
11941199 "medium".
1200+
1201+ class_names (list of strings, optional): List of class names. Used for
1202+ the legend. Order should be synchronized with the order of classes
1203+ in y_probas.
11951204
11961205 Returns:
11971206 ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
@@ -1214,6 +1223,7 @@ def plot_lift_curve(y_true, y_probas, title='Lift Curve',
12141223 y_probas = np .array (y_probas )
12151224
12161225 classes = np .unique (y_true )
1226+ if class_names is None : class_names = classes
12171227 if len (classes ) != 2 :
12181228 raise ValueError ('Cannot calculate Lift Curve for data with '
12191229 '{} category/ies' .format (len (classes )))
@@ -1236,8 +1246,8 @@ def plot_lift_curve(y_true, y_probas, title='Lift Curve',
12361246
12371247 ax .set_title (title , fontsize = title_fontsize )
12381248
1239- ax .plot (percentages , gains1 , lw = 3 , label = 'Class {}' .format (classes [0 ]))
1240- ax .plot (percentages , gains2 , lw = 3 , label = 'Class {}' .format (classes [1 ]))
1249+ ax .plot (percentages , gains1 , lw = 3 , label = 'Class {}' .format (class_names [0 ]))
1250+ ax .plot (percentages , gains2 , lw = 3 , label = 'Class {}' .format (class_names [1 ]))
12411251
12421252 ax .plot ([0 , 1 ], [1 , 1 ], 'k--' , lw = 2 , label = 'Baseline' )
12431253
0 commit comments