Skip to content

Commit 65e754c

Browse files
committed
plotting
1 parent e76fc44 commit 65e754c

File tree

3 files changed

+733
-0
lines changed

3 files changed

+733
-0
lines changed

plotting/plot_pareto_unweighted.py

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
import math
2+
import pandas as pd
3+
import seaborn as sns
4+
import matplotlib
5+
import matplotlib.pyplot as plt
6+
import re
7+
import numpy as np
8+
import plotting_utils
9+
from plotting_utils import *
10+
11+
# Global matplotlib configuration applied once at import time: AFM fonts for
# PostScript output, the 14 core fonts for PDF output, and LaTeX text rendering.
# text.usetex needs a working LaTeX installation, e.g.:
# sudo apt-get install texlive-latex-base texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra
plt.rcParams.update(
    {
        "ps.useafm": True,
        "pdf.use14corefonts": True,
        "text.usetex": True,
    }
)

# Module-level font size constant.
fontsize = 25
16+
17+
18+
def plotParetoAxis(ax, dfs, graph, lines, labels, clusterers):
    """Plot F-score against clustering time for each clusterer on one axes.

    Args:
        ax: Matplotlib axes to draw on.
        dfs: Mapping indexed by ``(graph, clusterer)``. Each value is a
            two-element tuple whose second element is a data frame with
            "Cluster Time" and "fScore_mean" columns.
        graph: Graph name; used for the lookup in ``dfs`` and as the title.
        lines: List extended in place with the line handle of every clusterer
            whose shortened name is not yet in ``labels``.
        labels: List extended in place with the shortened clusterer names
            (the substring "Clusterer" removed).
        clusterers: Iterable of clusterer names to draw.
    """
    marker_style = {"markersize": 16, "linewidth": 2}

    for name in clusterers:
        _first, frontier = dfs[(graph, name)]
        if frontier.empty:
            # Nothing to draw for this (graph, clusterer) pair.
            continue

        [handle] = ax.plot(
            frontier["Cluster Time"],
            frontier["fScore_mean"],
            label=name,
            color=color_map[name],
            marker=style_map[name],
            **marker_style,
        )

        # The shared figure legend gets one entry per clusterer, so only the
        # first handle seen for a given shortened name is recorded.
        legend_name = name.replace("Clusterer", "")
        if legend_name in labels:
            continue
        lines.append(handle)
        labels.append(legend_name)

    ax.set_xscale("log")
    ax.set_title(f"{graph}")
    ax.set_xlabel("Clustering Time")
    ax.set_ylabel("Mean $F_{0.5}$ Score")
47+
48+
49+
def plotPareto(dfs, graphs, clusterers, draw_legend=True, ncol=6):
    """Create a figure with one time-vs-F-score panel per graph.

    Exactly six graphs are laid out on a 2x3 grid (22x15 in); any other
    number is laid out on a single row (16x8 in). Each panel is drawn by
    ``plotParetoAxis``.

    Args:
        dfs: Mapping passed through to ``plotParetoAxis``.
        graphs: Sequence of graph names, one panel each.
        clusterers: Iterable of clusterer names drawn in every panel.
        draw_legend: When true, add one shared legend above the figure.
        ncol: Number of columns in the shared legend.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``graphs`` is empty.
    """
    num_graphs = len(graphs)
    if num_graphs == 0:
        # Fail with a clear message instead of an error from inside matplotlib.
        raise ValueError("plotPareto needs at least one graph to plot")

    plt.rcParams.update({"font.size": 25})

    if num_graphs == 6:
        nrows, ncols, figsize, legend_y = 2, 3, (22, 15), 1.1
    else:
        nrows, ncols, figsize, legend_y = 1, num_graphs, (16, 8), 1.15

    # squeeze=False always yields a 2-D array of axes, so a single graph
    # (1x1 grid) is handled the same way as any other layout.
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False
    )

    lines = []  # Line2D handles for the shared legend
    labels = []  # Legend labels matching ``lines``

    # axes.flat walks the grid row by row, matching the order of ``graphs``.
    for ax, graph in zip(axes.flat, graphs):
        plotParetoAxis(ax, dfs, graph, lines, labels, clusterers)

    if draw_legend:
        # One legend for the whole figure, placed above the panels.
        fig.legend(
            lines,
            labels,
            loc="upper center",
            ncol=ncol,
            bbox_to_anchor=(0.5, legend_y),
            frameon=False,
        )
    return fig
99+
100+
101+
def plotPRParetoAX(ax, graph, df, clusterers, lines, labels, only_high_p=False):
    """Plot recall against precision for each clusterer on one axes.

    Args:
        ax: Matplotlib axes to draw on.
        graph: Value of the "Input Graph" column to select; also the title.
        df: Data frame with "Clusterer Name", "Input Graph",
            "communityPrecision_mean" and "communityRecall_mean" columns.
        clusterers: Iterable of clusterer names to draw.
        lines: List extended in place with the line handle of every clusterer
            whose shortened name is not yet in ``labels``.
        labels: List extended in place with the shortened clusterer names
            (the substring "Clusterer" removed).
        only_high_p: When true, restrict the x-axis to the range 0.5 to 1.
    """
    for name in clusterers:
        by_clusterer = df["Clusterer Name"] == name
        by_graph = df["Input Graph"] == graph
        rows = df[by_clusterer & by_graph]
        if rows.empty:
            # This clusterer has no rows for this graph.
            continue

        [handle] = ax.plot(
            rows["communityPrecision_mean"],
            rows["communityRecall_mean"],
            label=name,
            color=color_map[name],
            marker=style_map[name],
            markersize=16,
            linewidth=2,
        )

        # Record only the first handle per shortened name so the shared
        # legend has no duplicate entries.
        legend_name = name.replace("Clusterer", "")
        if legend_name in labels:
            continue
        lines.append(handle)
        labels.append(legend_name)

    ax.set_title(f"{graph}")
    ax.set_xlabel("Precision")
    ax.set_ylabel("Recall")
    if only_high_p:
        ax.set_xlim((0.5, 1))
132+
133+
134+
def plotPRPareto(df, only_high_p=False, ncol=6):
    """Create a one-row figure of precision/recall panels, one per graph.

    Graphs and clusterers are the unique values of the "Input Graph" and
    "Clusterer Name" columns of ``df``; each panel is drawn by
    ``plotPRParetoAX``. A shared legend is placed above the figure.

    Args:
        df: Data frame with the columns required by ``plotPRParetoAX``.
        only_high_p: Forwarded to ``plotPRParetoAX``.
        ncol: Accepted for interface compatibility.
            NOTE(review): the legend below uses a fixed 4 columns and does not
            read this argument. Honouring it would change the figures existing
            callers produce, so confirm the intent before wiring it in.

    Returns:
        A 1-D array of the panel axes, in the order of the unique graphs.
    """
    graphs = df["Input Graph"].unique()
    clusterers = df["Clusterer Name"].unique()

    lines = []  # Line2D handles for the shared legend
    labels = []  # Legend labels matching ``lines``

    plt.rcParams.update({"font.size": 25})
    # squeeze=False keeps the result 2-D even for a single graph; row 0 is
    # then a 1-D array of axes whatever the number of panels.
    fig, axes_grid = plt.subplots(
        nrows=1, ncols=len(graphs), figsize=(16, 8), squeeze=False
    )
    axes = axes_grid[0]

    for ax, graph in zip(axes, graphs):
        plotPRParetoAX(ax, graph, df, clusterers, lines, labels, only_high_p)

    plt.tight_layout()
    fig.subplots_adjust(hspace=0.4)
    fig.legend(
        lines,
        labels,
        loc="upper center",
        ncol=4,
        bbox_to_anchor=(0.5, 1.2),
        frameon=False,
    )
    return axes
164+
165+
166+
def load_all_dfs():
    """Read the SNAP result CSVs and assemble the two frames used for plotting.

    Files are read from ``base_addr + "snap_results/..."``; ``base_addr`` is a
    module-level name that must be defined before this function is called.

    Returns:
        A ``(df_all, df_compare)`` tuple:
        - ``df_all``: the four main CSVs (without "ConnectivityClusterer" rows
          and without all-NaN rows, after ``add_epsilon_to_hac``) concatenated
          with the neo4j and nk results.
        - ``df_compare``: the "ParallelModularityClusterer" and
          "ParallelCorrelationClusterer" rows of the main frame concatenated
          with the neo4j, nk and tg results.
    """

    def read_stats(relative_path):
        # Paths are built by plain concatenation onto base_addr.
        return pd.read_csv(base_addr + relative_path)

    main_parts = [
        read_stats("snap_results/stats_snap_mod.csv"),
        read_stats("snap_results/stats_snap_ours.csv"),
        read_stats("snap_results/stats_snap_more.csv"),
        read_stats("snap_results/stats_snap_more_2.csv"),
    ]
    neo4j_frame = pd.concat(
        [
            read_stats("snap_results/stats_snap_neo4j.csv"),
            read_stats("snap_results/stats_snap_neo4j_more.csv"),
        ]
    )
    nk_frame = pd.concat(
        [
            read_stats("snap_results/stats_snap_nk.csv"),
            read_stats("snap_results/stats_snap_nk_more.csv"),
        ]
    )
    tg_frame = read_stats("snap_results/stats_snap_tg.csv")

    main_frame = pd.concat(main_parts)
    main_frame = main_frame[main_frame["Clusterer Name"] != "ConnectivityClusterer"]
    main_frame = main_frame.dropna(how="all")

    # The return value is discarded, so replace_graph_names presumably edits
    # each frame in place -- TODO confirm against plotting_utils.
    for frame in (main_frame, neo4j_frame, nk_frame, tg_frame):
        replace_graph_names(frame)

    main_frame = add_epsilon_to_hac(main_frame)

    df_all = pd.concat([main_frame, neo4j_frame, nk_frame])

    compared_methods = [
        "ParallelModularityClusterer",
        "ParallelCorrelationClusterer",
    ]
    is_compared = main_frame["Clusterer Name"].isin(compared_methods)
    df_compare = pd.concat(
        [main_frame[is_compared], neo4j_frame, nk_frame, tg_frame]
    )
    return df_all, df_compare
201+
202+
203+
def plot_weighted():
    """Load the SNAP results and write the five PDF figures under ./results.

    Figures written, in order:
        - pr_snap_subset.pdf: precision/recall for our methods on LJ and FS.
        - time_f1_snap_subset.pdf: time vs F-score for the same selection.
        - pr_snap_modularity_subset.pdf: precision/recall for the comparison
          frame restricted to LJ and OK.
        - time_f1_snap_modularity_subset.pdf: time vs F-score for that subset.
        - time_f1_snap_modularity.pdf: time vs F-score for the whole
          comparison frame.

    The ./results directory is created if it does not exist. A
    "plotted <path>" line is printed after each file is saved.
    """
    import os

    def save_current_figure(stem):
        # Save the current pyplot figure and report where it went.
        path = f"./results/{stem}.pdf"
        plt.savefig(path, bbox_inches='tight')
        print(f"plotted {path}")

    df, df_compare = load_all_dfs()
    our_methods = plotting_utils.get_our_methods()

    # savefig does not create missing directories, so make sure the output
    # location exists before the first figure is written.
    os.makedirs("./results", exist_ok=True)

    datasets = ["LJ", "FS"]
    subset = "_subset"

    # --- Our methods on the selected datasets ---
    df_subset = df[df["Input Graph"].isin(datasets)]
    df_ours = df_subset[df_subset["Clusterer Name"].isin(our_methods)]

    df_pr_pareto = FilterParetoPRMethod(df_ours)
    # getAUCTable(df_ours, df_pr_pareto)
    axes = plotPRPareto(df_pr_pareto, True, ncol=6)
    # Cap the recall axis of the second panel only.
    axes[1].set_ylim((0, 0.8))
    save_current_figure(f"pr_snap{subset}")

    clusterers = df_ours["Clusterer Name"].unique()
    dfs, graphs = GetParetoDfs(df_ours)
    plotPareto(dfs, graphs, clusterers, False)
    plt.tight_layout()
    save_current_figure(f"time_f1_snap{subset}")

    # --- Comparison frame restricted to LJ and OK ---
    df_subset = df_compare[df_compare["Input Graph"].isin(["LJ", "OK"])]

    df_pr_pareto = FilterParetoPRMethod(df_subset)
    plotPRPareto(df_pr_pareto)
    save_current_figure(f"pr_snap_modularity{subset}")

    clusterers = df_subset["Clusterer Name"].unique()
    dfs, graphs = GetParetoDfs(df_subset)
    plotPareto(dfs, graphs, clusterers, True, ncol=4)
    plt.tight_layout()
    save_current_figure(f"time_f1_snap_modularity{subset}")

    # --- Whole comparison frame ---
    clusterers = df_compare["Clusterer Name"].unique()
    dfs, graphs = GetParetoDfs(df_compare)
    plotPareto(dfs, graphs, clusterers)
    plt.tight_layout()
    save_current_figure("time_f1_snap_modularity")
247+
248+
if __name__ == "__main__":
    import os
    import sys

    # Root directory of the result CSVs. It can be overridden with the first
    # command-line argument; with no argument the original location is used.
    base_addr = "/Users/sy/Desktop/MIT/clusterer/csv/"
    if len(sys.argv) > 1:
        # load_all_dfs() appends relative paths by string concatenation, so
        # the root must end with a path separator; joining with "" adds one
        # when it is missing.
        base_addr = os.path.join(sys.argv[1], "")
    plot_weighted()

0 commit comments

Comments
 (0)