@@ -132,7 +132,6 @@ def plotParetoAxis(ax, dfs, graph, lines, labels, clusterers):
132
132
# Extract the pareto_df for the current graph and clusterer combination
133
133
_ , pareto_df = dfs [(graph , clusterer )]
134
134
if pareto_df .empty :
135
- # print(graph, clusterer)
136
135
continue
137
136
138
137
# Plot the pareto_df with the appropriate marker
@@ -159,72 +158,35 @@ def plotParetoAxis(ax, dfs, graph, lines, labels, clusterers):
159
158
160
159
161
160
def plotPareto(dfs, graphs, clusterers, draw_legend=True, ncol=6):
    """Plot the Pareto frontier of a single graph on one 8x8 figure.

    Args:
        dfs: Mapping forwarded unchanged to ``plotParetoAxis``.
        graphs: Sequence holding exactly one graph name.
        clusterers: Clusterer names forwarded to ``plotParetoAxis``.
        draw_legend: When true, add one figure-level legend above the plot.
        ncol: Number of columns in the legend.

    Returns:
        The matplotlib axes the frontier was drawn on.

    Raises:
        ValueError: If ``graphs`` does not contain exactly one entry.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(graphs) != 1:
        raise ValueError(
            f"plotPareto expects exactly one graph, got {len(graphs)}"
        )

    # Set the font size before creating the figure so text created together
    # with the axes uses it (same order as plotPRPareto).
    plt.rcParams.update({"font.size": 20})
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))

    lines = []  # Line2D handles for the legend, filled by plotParetoAxis
    labels = []  # Labels matching ``lines``, filled by plotParetoAxis

    plotParetoAxis(ax, dfs, graphs[0], lines, labels, clusterers)

    if draw_legend:
        # Single legend for the whole figure, anchored above its top edge.
        fig.legend(
            lines,
            labels,
            loc="upper center",
            ncol=ncol,
            bbox_to_anchor=(0.5, 1.15),
            frameon=False,
        )

    return ax
221
183
222
184
223
185
def plotPRParetoAX (ax , graph , df , clusterers , lines , labels , only_high_p = False ):
224
186
for clusterer in clusterers :
225
187
# Extract the pareto_df for the current graph and clusterer combination
226
188
pareto_df = df [
227
- (df ["Clusterer Name" ] == clusterer ) & (df ["Input Graph" ] == graph )
189
+ (df ["Clusterer Name" ] == clusterer ) # & (df["Input Graph"] == graph)
228
190
]
229
191
if pareto_df .empty :
230
192
continue
@@ -253,50 +215,51 @@ def plotPRParetoAX(ax, graph, df, clusterers, lines, labels, only_high_p=False):
253
215
ax .set_xlim ((0.5 , 1 ))
254
216
255
217
256
def plotPRPareto(dfs, only_high_p=False, ncol=6):
    """Plot one precision-recall Pareto panel per entry of ``dfs``.

    Args:
        dfs: Mapping from a parameter value to a DataFrame. Each DataFrame
            must contain exactly one distinct value in its "Input Graph"
            column and must have a "Clusterer Name" column.
        only_high_p: Forwarded unchanged to ``plotPRParetoAX``.
        ncol: Number of columns in the figure-level legend.

    Returns:
        What ``plt.subplots`` returned for the axes: a single axes object when
        ``dfs`` has one entry, otherwise the array of axes (one per entry).

    Raises:
        ValueError: If ``dfs`` is empty, or a DataFrame does not contain
            exactly one input graph.
    """
    num_params = len(dfs)
    if num_params == 0:
        raise ValueError("plotPRPareto needs at least one DataFrame in dfs")

    plt.rcParams.update({"font.size": 20})

    # A wide strip for several panels, a square figure for a single one.
    multi_panel = num_params > 1
    figsize = (25, 5) if multi_panel else (8, 8)
    fig, axes = plt.subplots(nrows=1, ncols=num_params, figsize=figsize)
    # With ncols=1, plt.subplots returns a bare axes object, not an array.
    panels = axes if multi_panel else [axes]

    lines = []  # Line2D handles for the legend, filled by plotPRParetoAX
    labels = []  # Labels matching ``lines``, filled by plotPRParetoAX

    for ax, (param, df) in zip(panels, dfs.items()):
        graphs = df["Input Graph"].unique()
        # Explicit check instead of ``assert`` so it still runs under -O.
        if len(graphs) != 1:
            raise ValueError(
                f"expected exactly one input graph for parameter {param!r}, "
                f"got {len(graphs)}"
            )
        clusterers = df["Clusterer Name"].unique()
        plotPRParetoAX(
            ax,
            f"{graphs[0]}_{param}",
            df,
            clusterers,
            lines,
            labels,
            only_high_p,
        )

    # Single legend for the whole figure, anchored above its top edge.
    fig.legend(
        lines,
        labels,
        loc="upper center",
        ncol=ncol,
        bbox_to_anchor=(0.5, 1.15),
        frameon=False,
    )
    return axes
310
273
311
274
312
275
def plotPRParetoSingle (df , graph ):
@@ -420,9 +383,9 @@ def getAUCTable(df, df_pr_pareto, print_table=False):
420
383
421
384
422
385
def plot_ngrams ():
423
- # df_pcbs = pd.read_csv(base_addr + f"out_ngrams_pcbs_csv/stats.csv")
386
+ df_pcbs = pd .read_csv (base_addr + f"out_ngrams_pcbs_csv/stats.csv" )
424
387
df_pcbs_high_res = pd .read_csv (base_addr + f"out_ngrams_high_res_pcbs_csv/stats.csv" )
425
- df = pd .concat ([df_pcbs_high_res ]) # df_pcbs,
388
+ df = pd .concat ([df_pcbs , df_pcbs_high_res ])
426
389
427
390
df = df .dropna (how = "all" )
428
391
replace_graph_names (df )
@@ -447,32 +410,46 @@ def plot_ngrams():
447
410
"ParHACClusterer_1" ,
448
411
]
449
412
450
-
451
- thresholds = [0.86 , 0.88 , 0.90 , 0.92 , 0.94 ]
452
-
453
- for threshold in thresholds :
413
def get_threshold_df(threshold):
    """Return the ``our_methods`` rows of ``df`` with metrics taken at ``threshold``.

    ``df`` and ``our_methods`` are read from the enclosing scope. In the
    returned frame each cell of the "fScore_mean", "communityPrecision_mean"
    and "communityRecall_mean" columns is replaced by ``cell[threshold]``.

    Args:
        threshold: Key used to index every cell of the three metric columns.

    Returns:
        A new DataFrame; ``df`` itself is left unmodified.
    """
    # Copy the selection so the column assignments below write to an
    # independent frame rather than to a slice of ``df``.
    df_pcbs = df[df["Clusterer Name"].isin(our_methods)].copy()

    # Index only the selected rows: rows of other clusterers are never
    # touched, and no index alignment against ``df`` is needed (which would
    # fail if ``df`` carries duplicate index labels).
    metric_columns = (
        "fScore_mean",
        "communityPrecision_mean",
        "communityRecall_mean",
    )
    for column in metric_columns:
        df_pcbs[column] = df_pcbs[column].apply(lambda k: k[threshold])
    return df_pcbs
420
+
421
+ thresholds = [0.88 , 0.90 , 0.92 , 0.94 ]
422
+ df_pr_paretos = {}
423
+
424
+ for threshold in thresholds :
425
+ df_pcbs = get_threshold_df (threshold )
459
426
460
- # Get AUC table
461
427
df_pr_pareto = FilterParetoPRMethod (df_pcbs )
462
- getAUCTable (df_pcbs , df_pr_pareto )
463
-
464
- # Plot Precision Recall Pareto frontier for PCBS methods
465
- axes = plotPRPareto (df_pr_pareto , only_high_p = True ) #
466
- plt .savefig (base_addr + f"pr_uci_{ threshold } .pdf" , bbox_inches = "tight" )
467
- print ("plotted pr_uci.pdf" )
468
-
469
- # Plot F_0.5 runtime Pareto frontier for PCBS methods
470
- clusterers = df_pcbs ["Clusterer Name" ].unique ()
471
- dfs , graphs = GetParetoDfs (df_pcbs )
472
- plotPareto (dfs , graphs , clusterers )
473
- plt .tight_layout ()
474
- plt .savefig (base_addr + f"time_f1_uci_{ threshold } .pdf" , bbox_inches = "tight" )
475
- print ("plotted time_f1_uci.pdf" )
428
+ df_pr_paretos [threshold ] = df_pr_pareto
429
+
430
+ # Plot Precision Recall Pareto frontier for PCBS methods
431
+ plotPRPareto (df_pr_paretos , only_high_p = True ) #
432
+ plt .savefig (base_addr + f"pr_uci.pdf" , bbox_inches = "tight" )
433
+ print ("plotted pr_uci.pdf" )
434
+
435
+ # plot single example
436
+ threshold = 0.92
437
+ df_pcbs = get_threshold_df (threshold )
438
+ df_pr_pareto = FilterParetoPRMethod (df_pcbs )
439
+ getAUCTable (df_pcbs , df_pr_pareto )
440
+ ax = plotPRPareto ({threshold :df_pr_pareto }, only_high_p = True , ncol = 3 )
441
+ ax .set_title ("" )
442
+ plt .savefig (base_addr + f"pr_uci_{ threshold } .pdf" , bbox_inches = "tight" )
443
+ print (f"plotted pr_uci_{ threshold } .pdf" )
444
+
445
+ # Plot F_0.5 runtime Pareto frontier for PCBS methods
446
+ clusterers = df_pcbs ["Clusterer Name" ].unique ()
447
+ dfs , graphs = GetParetoDfs (df_pcbs )
448
+ ax = plotPareto (dfs , graphs , clusterers , draw_legend = False )
449
+ ax .set_title ("" )
450
+ plt .tight_layout ()
451
+ plt .savefig (base_addr + f"time_f1_uci_{ threshold } .pdf" , bbox_inches = "tight" )
452
+ print (f"plotted time_f1_uci_{ threshold } .pdf" )
476
453
477
454
if __name__ == "__main__" :
478
455
base_addr = "results/"
0 commit comments