Skip to content

Commit 7d73afb

Browse files
committed
plotting
1 parent 7a6d09c commit 7d73afb

File tree

1 file changed

+78
-101
lines changed

1 file changed

+78
-101
lines changed

plotting/plot_pareto_ngrams.py

Lines changed: 78 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ def plotParetoAxis(ax, dfs, graph, lines, labels, clusterers):
132132
# Extract the pareto_df for the current graph and clusterer combination
133133
_, pareto_df = dfs[(graph, clusterer)]
134134
if pareto_df.empty:
135-
# print(graph, clusterer)
136135
continue
137136

138137
# Plot the pareto_df with the appropriate marker
@@ -159,72 +158,35 @@ def plotParetoAxis(ax, dfs, graph, lines, labels, clusterers):
159158

160159

161160
def plotPareto(dfs, graphs, clusterers, draw_legend=True, ncol=6):
161+
assert len(graphs)==1
162+
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
163+
plt.rcParams.update({"font.size": 20})
162164

163-
if len(graphs) > 4:
164-
plt.rcParams.update({"font.size": 25})
165-
166-
# Create subplots in a 2x3 grid
167-
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(22, 15))
168-
graph_idx = 0
169-
170-
lines = [] # To store the Line2D objects for the legend
171-
labels = [] # To store the corresponding labels for the Line2D objects
172-
173-
for i in range(2):
174-
for j in range(3):
175-
if graph_idx < len(graphs): # Ensure we have a graph to process
176-
graph = graphs[graph_idx]
177-
ax = axes[i][j]
178-
plotParetoAxis(ax, dfs, graph, lines, labels, clusterers)
179-
graph_idx += 1
180-
else:
181-
axes[i][j].axis("off") # Turn off axes without data
165+
lines = [] # To store the Line2D objects for the legend
166+
labels = [] # To store the corresponding labels for the Line2D objects
167+
168+
graph = graphs[0]
169+
plotParetoAxis(ax, dfs, graph, lines, labels, clusterers)
170+
171+
if draw_legend:
182172
# Create a single legend for the entire figure, at the top
183173
fig.legend(
184174
lines,
185175
labels,
186176
loc="upper center",
187177
ncol=ncol,
188-
bbox_to_anchor=(0.5, 1.1),
178+
bbox_to_anchor=(0.5, 1.15),
189179
frameon=False,
190180
)
191-
else:
192-
# Create subplots in a 2x3 grid
193-
plt.rcParams.update({"font.size": 20})
194-
195-
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(25, 5))
196-
graph_idx = 0
197-
198-
lines = [] # To store the Line2D objects for the legend
199-
labels = [] # To store the corresponding labels for the Line2D objects
200-
201-
for graph_idx in range(4):
202-
if graph_idx < len(graphs): # Ensure we have a graph to process
203-
graph = graphs[graph_idx]
204-
ax = axes[graph_idx]
205-
plotParetoAxis(ax, dfs, graph, lines, labels, clusterers)
206-
graph_idx += 1
207-
else:
208-
axes[graph_idx].axis("off") # Turn off axes without data
209-
if draw_legend:
210-
# Create a single legend for the entire figure, at the top
211-
fig.legend(
212-
lines,
213-
labels,
214-
loc="upper center",
215-
ncol=6,
216-
bbox_to_anchor=(0.5, 1.15),
217-
frameon=False,
218-
)
219181

220-
return fig
182+
return ax
221183

222184

223185
def plotPRParetoAX(ax, graph, df, clusterers, lines, labels, only_high_p=False):
224186
for clusterer in clusterers:
225187
# Extract the pareto_df for the current graph and clusterer combination
226188
pareto_df = df[
227-
(df["Clusterer Name"] == clusterer) & (df["Input Graph"] == graph)
189+
(df["Clusterer Name"] == clusterer) #& (df["Input Graph"] == graph)
228190
]
229191
if pareto_df.empty:
230192
continue
@@ -253,50 +215,51 @@ def plotPRParetoAX(ax, graph, df, clusterers, lines, labels, only_high_p=False):
253215
ax.set_xlim((0.5, 1))
254216

255217

256-
def plotPRPareto(df, only_high_p=False, ncol=6):
257-
graphs = df["Input Graph"].unique()
258-
clusterers = df["Clusterer Name"].unique()
218+
def plotPRPareto(dfs, only_high_p=False, ncol=6):
259219

260220
graph_idx = 0
261221

262222
lines = [] # To store the Line2D objects for the legend
263223
labels = [] # To store the corresponding labels for the Line2D objects
264224

265-
if len(graphs) > 4:
266-
plt.rcParams.update({"font.size": 25})
267-
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(30, 16))
268-
for i in range(2):
269-
for j in range(3):
270-
if graph_idx < len(graphs): # Ensure we have a graph to process
271-
graph = graphs[graph_idx]
272-
ax = axes[i][j]
225+
226+
num_params = len(dfs)
227+
228+
plt.rcParams.update({"font.size": 20})
273229

274-
plotPRParetoAX(
275-
ax, graph, df, clusterers, lines, labels, only_high_p
276-
)
230+
if num_params > 1:
231+
fig, axes = plt.subplots(nrows=1, ncols=num_params, figsize=(25, 5))
232+
for param_idx, param in enumerate(dfs.keys()):
233+
df = dfs[param]
234+
graphs = df["Input Graph"].unique()
235+
clusterers = df["Clusterer Name"].unique()
236+
assert len(graphs)==1
237+
graph = graphs[0]
238+
239+
ax = axes[param_idx]
240+
241+
plotPRParetoAX(ax, f"{graph}_{param}", df, clusterers, lines, labels, only_high_p)
277242

278-
graph_idx += 1
279-
else:
280-
axes[i][j].axis("off") # Turn off axes without data
281243

282244
fig.legend(
283245
lines,
284246
labels,
285247
loc="upper center",
286248
ncol=ncol,
287-
bbox_to_anchor=(0.5, 1),
249+
bbox_to_anchor=(0.5, 1.15),
288250
frameon=False,
289251
)
252+
return axes
253+
290254
else:
291-
plt.rcParams.update({"font.size": 20})
292-
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(25, 5))
293-
for graph_idx in range(len(graphs)):
294-
graph = graphs[graph_idx]
295-
ax = axes[graph_idx]
296-
297-
plotPRParetoAX(ax, graph, df, clusterers, lines, labels, only_high_p)
298-
299-
graph_idx += 1
255+
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
256+
param = [k for k in dfs.keys()][0]
257+
df = dfs[param]
258+
graphs = df["Input Graph"].unique()
259+
clusterers = df["Clusterer Name"].unique()
260+
assert len(graphs)==1
261+
graph = graphs[0]
262+
plotPRParetoAX(ax, f"{graph}_{param}", df, clusterers, lines, labels, only_high_p)
300263

301264
fig.legend(
302265
lines,
@@ -305,8 +268,8 @@ def plotPRPareto(df, only_high_p=False, ncol=6):
305268
ncol=ncol,
306269
bbox_to_anchor=(0.5, 1.15),
307270
frameon=False,
308-
)
309-
return axes
271+
)
272+
return ax
310273

311274

312275
def plotPRParetoSingle(df, graph):
@@ -420,9 +383,9 @@ def getAUCTable(df, df_pr_pareto, print_table=False):
420383

421384

422385
def plot_ngrams():
423-
# df_pcbs = pd.read_csv(base_addr + f"out_ngrams_pcbs_csv/stats.csv")
386+
df_pcbs = pd.read_csv(base_addr + f"out_ngrams_pcbs_csv/stats.csv")
424387
df_pcbs_high_res = pd.read_csv(base_addr + f"out_ngrams_high_res_pcbs_csv/stats.csv")
425-
df = pd.concat([df_pcbs_high_res]) #df_pcbs,
388+
df = pd.concat([df_pcbs, df_pcbs_high_res])
426389

427390
df = df.dropna(how="all")
428391
replace_graph_names(df)
@@ -447,32 +410,46 @@ def plot_ngrams():
447410
"ParHACClusterer_1",
448411
]
449412

450-
451-
thresholds = [0.86, 0.88, 0.90, 0.92, 0.94]
452-
453-
for threshold in thresholds:
413+
def get_threshold_df(threshold):
454414
df_pcbs = df[df["Clusterer Name"].isin(our_methods)]
455415

456416
df_pcbs["fScore_mean"] = df["fScore_mean"].apply(lambda k: k[threshold])
457417
df_pcbs["communityPrecision_mean"] = df["communityPrecision_mean"].apply(lambda k: k[threshold])
458418
df_pcbs["communityRecall_mean"] = df["communityRecall_mean"].apply(lambda k: k[threshold])
419+
return df_pcbs
420+
421+
thresholds = [0.88, 0.90, 0.92, 0.94]
422+
df_pr_paretos = {}
423+
424+
for threshold in thresholds:
425+
df_pcbs = get_threshold_df(threshold)
459426

460-
# Get AUC table
461427
df_pr_pareto = FilterParetoPRMethod(df_pcbs)
462-
getAUCTable(df_pcbs, df_pr_pareto)
463-
464-
# Plot Precision Recall Pareto frontier for PCBS methods
465-
axes = plotPRPareto(df_pr_pareto, only_high_p=True) #
466-
plt.savefig(base_addr + f"pr_uci_{threshold}.pdf", bbox_inches="tight")
467-
print("plotted pr_uci.pdf")
468-
469-
# Plot F_0.5 runtime Pareto frontier for PCBS methods
470-
clusterers = df_pcbs["Clusterer Name"].unique()
471-
dfs, graphs = GetParetoDfs(df_pcbs)
472-
plotPareto(dfs, graphs, clusterers)
473-
plt.tight_layout()
474-
plt.savefig(base_addr + f"time_f1_uci_{threshold}.pdf", bbox_inches="tight")
475-
print("plotted time_f1_uci.pdf")
428+
df_pr_paretos[threshold] = df_pr_pareto
429+
430+
# Plot Precision Recall Pareto frontier for PCBS methods
431+
plotPRPareto(df_pr_paretos, only_high_p=True) #
432+
plt.savefig(base_addr + f"pr_uci.pdf", bbox_inches="tight")
433+
print("plotted pr_uci.pdf")
434+
435+
# plot single example
436+
threshold = 0.92
437+
df_pcbs = get_threshold_df(threshold)
438+
df_pr_pareto = FilterParetoPRMethod(df_pcbs)
439+
getAUCTable(df_pcbs, df_pr_pareto)
440+
ax = plotPRPareto({threshold:df_pr_pareto}, only_high_p=True, ncol=3)
441+
ax.set_title("")
442+
plt.savefig(base_addr + f"pr_uci_{threshold}.pdf", bbox_inches="tight")
443+
print(f"plotted pr_uci_{threshold}.pdf")
444+
445+
# Plot F_0.5 runtime Pareto frontier for PCBS methods
446+
clusterers = df_pcbs["Clusterer Name"].unique()
447+
dfs, graphs = GetParetoDfs(df_pcbs)
448+
ax = plotPareto(dfs, graphs, clusterers, draw_legend=False)
449+
ax.set_title("")
450+
plt.tight_layout()
451+
plt.savefig(base_addr + f"time_f1_uci_{threshold}.pdf", bbox_inches="tight")
452+
print(f"plotted time_f1_uci_{threshold}.pdf")
476453

477454
if __name__ == "__main__":
478455
base_addr = "results/"

0 commit comments

Comments
 (0)