@@ -272,28 +272,6 @@ def plotPRPareto(dfs, only_high_p=False, ncol=6):
272
272
return ax
273
273
274
274
275
- def plotPRParetoSingle (df , graph ):
276
- plt .rcParams .update ({"font.size" : 20 })
277
- clusterers = df ["Clusterer Name" ].unique ()
278
-
279
- lines = [] # To store the Line2D objects for the legend
280
- labels = [] # To store the corresponding labels for the Line2D objects
281
-
282
- fig , ax = plt .subplots (nrows = 1 , ncols = 1 , figsize = (8 , 8 ))
283
-
284
- plotPRParetoAX (ax , graph , df , clusterers , lines , labels )
285
-
286
- fig .subplots_adjust (hspace = 0.4 )
287
- fig .legend (
288
- lines ,
289
- labels ,
290
- loc = "upper left" ,
291
- ncol = 1 ,
292
- bbox_to_anchor = (0.9 , 0.8 ),
293
- frameon = False ,
294
- )
295
-
296
-
297
275
# compute the area under the precision recall pareto curve, for precision >= 0.5.
298
276
def computeAUC (df_pr_pareto , clusterer , graph ):
299
277
df = df_pr_pareto [
@@ -380,6 +358,46 @@ def getAUCTable(df, df_pr_pareto, print_table=False):
380
358
print (latex_table )
381
359
382
360
361
+ def plot_single_threshold (threshold , df_pcbs ):
362
+ graphs = df_pcbs ["Input Graph" ].unique ()
363
+ assert len (graphs )== 1
364
+ graph = graphs [0 ]
365
+
366
+
367
+ clusterers = df_pcbs ["Clusterer Name" ].unique ()
368
+ df_pr_pareto = FilterParetoPRMethod (df_pcbs )
369
+ getAUCTable (df_pcbs , df_pr_pareto )
370
+
371
+ fig , axs = plt .subplots (nrows = 1 , ncols = 2 , figsize = (16 , 8 ))
372
+ plt .rcParams .update ({"font.size" : 25 })
373
+
374
+ lines = [] # To store the Line2D objects for the legend
375
+ labels = [] # To store the corresponding labels for the Line2D objects
376
+ plotPRParetoAX (axs [0 ], graph , df_pr_pareto , clusterers , lines , labels , only_high_p = True )
377
+
378
+ # Plot F_0.5 runtime Pareto frontier for PCBS methods
379
+ dfs , graphs = GetParetoDfs (df_pcbs )
380
+ plotParetoAxis (axs [1 ], dfs , graph , [], [], clusterers )
381
+
382
+ for ax in axs :
383
+ ax .set_title ("" )
384
+ ax .set_xlabel (ax .get_xlabel (), fontsize = 25 )
385
+ ax .set_ylabel (ax .get_ylabel (), fontsize = 25 )
386
+ ax .tick_params (axis = 'both' , which = 'major' , labelsize = 25 )
387
+
388
+ plt .tight_layout ()
389
+ fig .subplots_adjust (hspace = 0.4 )
390
+ fig .legend (
391
+ lines ,
392
+ labels ,
393
+ loc = "upper center" ,
394
+ ncol = 4 ,
395
+ bbox_to_anchor = (0.5 , 1.2 ),
396
+ frameon = False ,
397
+ )
398
+ plt .savefig (base_addr + f"ngrams_{ threshold } .pdf" , bbox_inches = "tight" )
399
+ print (f"plotted ngrams_{ threshold } .pdf" )
400
+
383
401
base_addr = "./results/"
384
402
385
403
@@ -436,21 +454,8 @@ def get_threshold_df(threshold):
436
454
# plot single example
437
455
threshold = 0.92
438
456
df_pcbs = get_threshold_df (threshold )
439
- df_pr_pareto = FilterParetoPRMethod (df_pcbs )
440
- getAUCTable (df_pcbs , df_pr_pareto )
441
- ax = plotPRPareto ({threshold :df_pr_pareto }, only_high_p = True , ncol = 3 )
442
- ax .set_title ("" )
443
- plt .savefig (base_addr + f"pr_ngrams_{ threshold } .pdf" , bbox_inches = "tight" )
444
- print (f"plotted pr_ngrams_{ threshold } .pdf" )
457
+ plot_single_threshold (threshold , df_pcbs )
445
458
446
- # Plot F_0.5 runtime Pareto frontier for PCBS methods
447
- clusterers = df_pcbs ["Clusterer Name" ].unique ()
448
- dfs , graphs = GetParetoDfs (df_pcbs )
449
- ax = plotPareto (dfs , graphs , clusterers , draw_legend = False )
450
- ax .set_title ("" )
451
- plt .tight_layout ()
452
- plt .savefig (base_addr + f"time_f1_ngrams_{ threshold } .pdf" , bbox_inches = "tight" )
453
- print (f"plotted time_f1_ngrams_{ threshold } .pdf" )
454
459
455
460
if __name__ == "__main__" :
456
461
base_addr = "results/"
0 commit comments