@@ -107,7 +107,7 @@ def badrate_plot(frame, x = None, target = 'target', by = None,
107107 return unpack_tuple (res )
108108
109109
110- def corr_plot (frame , figure_size = (20 , 15 )):
110+ def corr_plot (frame , figure_size = (20 , 15 ), ax = None ):
111111 """plot for correlation
112112
113113 Args:
@@ -133,12 +133,13 @@ def corr_plot(frame, figure_size = (20, 15)):
133133 annot = True ,
134134 fmt = '.2f' ,
135135 figure_size = figure_size ,
136+ ax = ax ,
136137 )
137138
138139 return map_plot
139140
140141
141- def proportion_plot (x = None , keys = None ):
142+ def proportion_plot (x = None , keys = None , ax = None ):
142143 """plot for comparing proportion in different dataset
143144
144145 Args:
@@ -175,12 +176,13 @@ def proportion_plot(x = None, keys = None):
175176 y = 'proportion' ,
176177 hue = 'keys' ,
177178 data = prop_data ,
179+ ax = ax ,
178180 )
179181
180182 return prop_plot
181183
182184
183- def roc_plot (score , target , compare = None , figsize = (14 , 10 )):
185+ def roc_plot (score , target , compare = None , figsize = (14 , 10 ), ax = None ):
184186 """plot for roc
185187
186188 Args:
@@ -193,7 +195,9 @@ def roc_plot(score, target, compare = None, figsize = (14, 10)):
193195 """
194196 auc , fpr , tpr , thresholds = AUC (score , target , return_curve = True )
195197
196- fig , ax = plt .subplots (1 , 1 , figsize = figsize )
198+ if ax is None :
199+ fig , ax = plt .subplots (1 , 1 , figsize = figsize )
200+
197201 ax .plot (fpr , tpr , label = 'ROC curve (area = %0.5f)' % auc )
198202 ax .fill_between (fpr , tpr , alpha = 0.3 )
199203 if compare is not None :
@@ -206,7 +210,7 @@ def roc_plot(score, target, compare = None, figsize = (14, 10)):
206210
207211 return ax
208212
209- def ks_plot (score , target , figsize = (14 , 10 )):
213+ def ks_plot (score , target , figsize = (14 , 10 ), ax = None ):
210214 """plot for ks
211215
212216 Args:
@@ -219,7 +223,9 @@ def ks_plot(score, target, figsize = (14, 10)):
219223 """
220224 fpr , tpr , thresholds = roc_curve (target , score )
221225
222- fig , ax = plt .subplots (1 , 1 , figsize = figsize )
226+ if ax is None :
227+ fig , ax = plt .subplots (1 , 1 , figsize = figsize )
228+
223229 ax .plot (thresholds [1 : ], tpr [1 : ], label = 'tpr' )
224230 ax .plot (thresholds [1 : ], fpr [1 : ], label = 'fpr' )
225231 ax .plot (thresholds [1 : ], (tpr - fpr )[1 : ], label = 'ks' )
@@ -235,7 +241,8 @@ def ks_plot(score, target, figsize = (14, 10)):
235241
236242 return ax
237243
238- def bin_plot (frame , x = None , target = 'target' , iv = True , annotate_format = ".2f" , return_frame = False , figsize = (12 , 6 )):
244+ def bin_plot (frame , x = None , target = 'target' , iv = True , annotate_format = ".2f" ,
245+ return_frame = False , figsize = (12 , 6 ), ax = None ):
239246 """plot for bins
240247
241248 Args:
@@ -258,18 +265,21 @@ def bin_plot(frame, x = None, target = 'target', iv = True, annotate_format = ".
258265 target = temp_name
259266
260267 table = feature_bin_stats (frame , x , target )
261- fig , prop_ax = plt .subplots (figsize = figsize )
262- prop_ax = tadpole .barplot (
268+
269+ if ax is None :
270+ fig , ax = plt .subplots (figsize = figsize )
271+
272+ ax = tadpole .barplot (
263273 x = x ,
264274 y = 'prop' ,
265275 data = table ,
266276 color = '#82C6E2' ,
267- ax = prop_ax ,
277+ ax = ax ,
268278 )
269279
270- prop_ax = add_annotate (prop_ax , format = annotate_format )
280+ ax = add_annotate (ax , format = annotate_format )
271281
272- badrate_ax = prop_ax .twinx ()
282+ badrate_ax = ax .twinx ()
273283 badrate_ax .grid (False )
274284
275285 badrate_ax = tadpole .lineplot (
@@ -284,10 +294,10 @@ def bin_plot(frame, x = None, target = 'target', iv = True, annotate_format = ".
284294 badrate_ax = add_annotate (badrate_ax , format = annotate_format )
285295
286296 if iv :
287- prop_ax = reset_ylim (prop_ax )
288- prop_ax = add_text (prop_ax , 'IV: {:.5f}' .format (table ['iv' ].sum ()))
297+ ax = reset_ylim (ax )
298+ ax = add_text (ax , 'IV: {:.5f}' .format (table ['iv' ].sum ()))
289299
290- res = (prop_ax ,)
300+ res = (ax ,)
291301
292302 if return_frame :
293303 res += (table ,)
0 commit comments