Skip to content

Commit e76fc44

Browse files
committed
plotting
1 parent 491807c commit e76fc44

File tree

1 file changed

+41
-36
lines changed

1 file changed

+41
-36
lines changed

plotting/plot_pareto_ngrams.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -272,28 +272,6 @@ def plotPRPareto(dfs, only_high_p=False, ncol=6):
272272
return ax
273273

274274

275-
def plotPRParetoSingle(df, graph):
    """Draw the precision-recall Pareto plot of one graph on a new 8x8 figure.

    Args:
        df: Table with a "Clusterer Name" column; it is forwarded unchanged
            to plotPRParetoAX.
        graph: Graph identifier forwarded to plotPRParetoAX.

    The function returns nothing and does not save the figure.
    """
    plt.rcParams.update({"font.size": 20})
    method_names = df["Clusterer Name"].unique()

    # Handed to plotPRParetoAX and then used as the legend entries;
    # presumably the helper appends Line2D handles and labels -- confirm.
    legend_handles = []
    legend_names = []

    figure, axis = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
    plotPRParetoAX(axis, graph, df, method_names, legend_handles, legend_names)

    figure.subplots_adjust(hspace=0.4)

    # Legend is anchored by its upper-left corner at (0.9, 0.8) in figure
    # coordinates, in a single column and without a frame.
    legend_options = {
        "loc": "upper left",
        "ncol": 1,
        "bbox_to_anchor": (0.9, 0.8),
        "frameon": False,
    }
    figure.legend(legend_handles, legend_names, **legend_options)
295-
296-
297275
# Compute the area under the precision-recall Pareto curve, restricted to precision >= 0.5.
298276
def computeAUC(df_pr_pareto, clusterer, graph):
299277
df = df_pr_pareto[
@@ -380,6 +358,46 @@ def getAUCTable(df, df_pr_pareto, print_table=False):
380358
print(latex_table)
381359

382360

361+
def plot_single_threshold(threshold, df_pcbs):
    """Plot the two Pareto views for one threshold side by side and save them.

    The left axis is drawn by plotPRParetoAX (with only_high_p=True) from the
    rows kept by FilterParetoPRMethod; the right axis is drawn by
    plotParetoAxis from the frames returned by GetParetoDfs. The figure is
    written to ``base_addr + f"ngrams_{threshold}.pdf"``, where ``base_addr``
    is a module-level global that must be set before this function is called.

    Args:
        threshold: Value used only to build the output file name and the
            progress message.
        df_pcbs: Table with "Input Graph" and "Clusterer Name" columns. It
            must contain exactly one distinct "Input Graph" value.

    Raises:
        AssertionError: If df_pcbs does not contain exactly one input graph.
    """
    font_size = 25

    graphs = df_pcbs["Input Graph"].unique()
    assert len(graphs) == 1, (
        f"expected exactly one input graph, got {len(graphs)}"
    )
    graph = graphs[0]

    clusterers = df_pcbs["Clusterer Name"].unique()
    df_pr_pareto = FilterParetoPRMethod(df_pcbs)
    getAUCTable(df_pcbs, df_pr_pareto)

    # Update the default font size BEFORE creating the figure so that text
    # objects created together with the axes pick it up as well.
    plt.rcParams.update({"font.size": font_size})
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))

    lines = []  # To store the Line2D objects for the legend
    labels = []  # To store the corresponding labels for the Line2D objects
    plotPRParetoAX(
        axs[0], graph, df_pr_pareto, clusterers, lines, labels, only_high_p=True
    )

    # Plot F_0.5 runtime Pareto frontier for PCBS methods
    # Only the frames are needed; the second return value is ignored. Fresh
    # lists are passed so this axis contributes nothing to the shared legend.
    dfs, _ = GetParetoDfs(df_pcbs)
    plotParetoAxis(axs[1], dfs, graph, [], [], clusterers)

    # Clear the titles and force one font size on axis labels and tick labels.
    for ax in axs:
        ax.set_title("")
        ax.set_xlabel(ax.get_xlabel(), fontsize=font_size)
        ax.set_ylabel(ax.get_ylabel(), fontsize=font_size)
        ax.tick_params(axis="both", which="major", labelsize=font_size)

    plt.tight_layout()
    fig.subplots_adjust(hspace=0.4)
    # The legend is anchored above the figure (y = 1.2);
    # bbox_inches="tight" below keeps it inside the saved PDF.
    fig.legend(
        lines,
        labels,
        loc="upper center",
        ncol=4,
        bbox_to_anchor=(0.5, 1.2),
        frameon=False,
    )
    plt.savefig(base_addr + f"ngrams_{threshold}.pdf", bbox_inches="tight")
    print(f"plotted ngrams_{threshold}.pdf")
400+
383401
base_addr = "./results/"
384402

385403

@@ -436,21 +454,8 @@ def get_threshold_df(threshold):
436454
# plot single example
437455
threshold = 0.92
438456
df_pcbs = get_threshold_df(threshold)
439-
df_pr_pareto = FilterParetoPRMethod(df_pcbs)
440-
getAUCTable(df_pcbs, df_pr_pareto)
441-
ax = plotPRPareto({threshold:df_pr_pareto}, only_high_p=True, ncol=3)
442-
ax.set_title("")
443-
plt.savefig(base_addr + f"pr_ngrams_{threshold}.pdf", bbox_inches="tight")
444-
print(f"plotted pr_ngrams_{threshold}.pdf")
457+
plot_single_threshold(threshold, df_pcbs)
445458

446-
# Plot F_0.5 runtime Pareto frontier for PCBS methods
447-
clusterers = df_pcbs["Clusterer Name"].unique()
448-
dfs, graphs = GetParetoDfs(df_pcbs)
449-
ax = plotPareto(dfs, graphs, clusterers, draw_legend=False)
450-
ax.set_title("")
451-
plt.tight_layout()
452-
plt.savefig(base_addr + f"time_f1_ngrams_{threshold}.pdf", bbox_inches="tight")
453-
print(f"plotted time_f1_ngrams_{threshold}.pdf")
454459

455460
if __name__ == "__main__":
456461
base_addr = "results/"

0 commit comments

Comments
 (0)