Skip to content

Commit 7a6d09c

Browse files
committed
multiple thresholds
1 parent eee3e48 commit 7a6d09c

File tree

7 files changed

+95
-81
lines changed

7 files changed

+95
-81
lines changed

configs_experiments/ngrams/cluster_ngraphs_high_res_pcbs.config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Input directory: /home/sy/ParClusterers/ngrams_graphs/NGramsGraphs/
22
Output directory: /home/sy/ParClusterers/results/out_ngrams_high_res_pcbs/
3-
CSV output directory: /home/sy/ParClusterers/results/out_ngrams_high_res_0.92_pcbs_csv/
3+
CSV output directory: /home/sy/ParClusterers/results/out_ngrams_high_res_pcbs_csv/
44
Clusterers: LDDClusterer;TectonicClusterer;ParallelAffinityClusterer;ParHacClusterer;ParallelCorrelationClusterer;ParallelModularityClusterer
55
Graphs: ngrams.graph.gbbs
66
GBBS format: true

configs_experiments/ngrams/cluster_ngraphs_pcbs.config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Input directory: /home/sy/ParClusterers/ngrams_graphs/NGramsGraphs/
22
Output directory: /home/sy/ParClusterers/results/out_ngrams_pcbs/
3-
CSV output directory: /home/sy/ParClusterers/results/out_ngrams_0.88_pcbs_csv/
3+
CSV output directory: /home/sy/ParClusterers/results/out_ngrams_pcbs_csv/
44
Clusterers: LDDClusterer;ScanClusterer;LabelPropagationClusterer;SLPAClusterer;TectonicClusterer;ConnectivityClusterer;ParallelAffinityClusterer;ParHacClusterer;ParallelCorrelationClusterer;ParallelModularityClusterer
55
Graphs: ngrams.graph.gbbs
66
GBBS format: true

configs_experiments/ngrams/stats_ngrams.config

Lines changed: 0 additions & 8 deletions
This file was deleted.

configs_experiments/ngrams/stats_pair_ngrams.config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@ Input communities: clusters.pair.cmty
22
Deterministic: false
33

44
statistics_config:
5-
precision_recall_pair_threshold: 0.92
5+
precision_recall_pair_thresholds: 0.86;0.88;0.90;0.92;0.94
66
f_score_param: 0.5

plotting/plot_pareto_ngrams.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import numpy as np
88
from plotting_utils import *
99

10+
import ast
11+
1012
plt.rcParams["ps.useafm"] = True
1113
plt.rcParams["pdf.use14corefonts"] = True
1214
# plt.rcParams["text.usetex"] = True
@@ -418,14 +420,16 @@ def getAUCTable(df, df_pr_pareto, print_table=False):
418420

419421

420422
def plot_ngrams():
421-
threshold = 0.92
422-
df_pcbs = pd.read_csv(base_addr + f"out_ngrams_{threshold}_pcbs_csv/stats.csv")
423-
df_pcbs_high_res = pd.read_csv(base_addr + f"out_ngrams_high_res_{threshold}_pcbs_csv/stats.csv")
424-
df = pd.concat([df_pcbs, df_pcbs_high_res])
423+
# df_pcbs = pd.read_csv(base_addr + f"out_ngrams_pcbs_csv/stats.csv")
424+
df_pcbs_high_res = pd.read_csv(base_addr + f"out_ngrams_high_res_pcbs_csv/stats.csv")
425+
df = pd.concat([df_pcbs_high_res]) #df_pcbs,
425426

426427
df = df.dropna(how="all")
427428
replace_graph_names(df)
428429
df = add_epsilon_to_hac(df)
430+
df["fScore_mean"] = df["fScore_mean"].apply(ast.literal_eval)
431+
df["communityPrecision_mean"] = df["communityPrecision_mean"].apply(ast.literal_eval)
432+
df["communityRecall_mean"] = df["communityRecall_mean"].apply(ast.literal_eval)
429433

430434
our_methods = [
431435
"KCoreClusterer",
@@ -443,25 +447,32 @@ def plot_ngrams():
443447
"ParHACClusterer_1",
444448
]
445449

446-
df_pcbs = df[df["Clusterer Name"].isin(our_methods)]
447-
448-
# Get AUC table
449-
df_pr_pareto = FilterParetoPRMethod(df_pcbs)
450-
getAUCTable(df_pcbs, df_pr_pareto)
451-
452-
# Plot Precision Recall Pareto frontier for PCBS methods
453-
axes = plotPRPareto(df_pr_pareto, only_high_p=True)
454-
axes[0].set_ylim((0.5, 0.8))
455-
plt.savefig(base_addr + f"pr_uci_{threshold}.pdf", bbox_inches="tight")
456-
print("plotted pr_uci.pdf")
457-
458-
# Plot F_0.5 runtime Pareto frontier for PCBS methods
459-
clusterers = df_pcbs["Clusterer Name"].unique()
460-
dfs, graphs = GetParetoDfs(df_pcbs)
461-
plotPareto(dfs, graphs, clusterers)
462-
plt.tight_layout()
463-
plt.savefig(base_addr + f"time_f1_uci_{threshold}.pdf", bbox_inches="tight")
464-
print("plotted time_f1_uci.pdf")
450+
451+
thresholds = [0.86, 0.88, 0.90, 0.92, 0.94]
452+
453+
for threshold in thresholds:
454+
df_pcbs = df[df["Clusterer Name"].isin(our_methods)]
455+
456+
df_pcbs["fScore_mean"] = df["fScore_mean"].apply(lambda k: k[threshold])
457+
df_pcbs["communityPrecision_mean"] = df["communityPrecision_mean"].apply(lambda k: k[threshold])
458+
df_pcbs["communityRecall_mean"] = df["communityRecall_mean"].apply(lambda k: k[threshold])
459+
460+
# Get AUC table
461+
df_pr_pareto = FilterParetoPRMethod(df_pcbs)
462+
getAUCTable(df_pcbs, df_pr_pareto)
463+
464+
# Plot Precision Recall Pareto frontier for PCBS methods
465+
axes = plotPRPareto(df_pr_pareto, only_high_p=True) #
466+
plt.savefig(base_addr + f"pr_uci_{threshold}.pdf", bbox_inches="tight")
467+
print("plotted pr_uci.pdf")
468+
469+
# Plot F_0.5 runtime Pareto frontier for PCBS methods
470+
clusterers = df_pcbs["Clusterer Name"].unique()
471+
dfs, graphs = GetParetoDfs(df_pcbs)
472+
plotPareto(dfs, graphs, clusterers)
473+
plt.tight_layout()
474+
plt.savefig(base_addr + f"time_f1_uci_{threshold}.pdf", bbox_inches="tight")
475+
print("plotted time_f1_uci.pdf")
465476

466477
if __name__ == "__main__":
467478
base_addr = "results/"

stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def runStats(out_prefix, graph, graph_idx, stats_dict):
5252
return
5353
use_input_graph = runner_utils.input_directory + graph
5454
input_communities = runner_utils.input_directory + runner_utils.communities[graph_idx]
55-
if "precision_recall_pair_threshold" in runner_utils.stats_config:
55+
if "precision_recall_pair_thresholds" in runner_utils.stats_config:
5656
compute_precision_recall_pair(in_clustering, input_communities, out_statistics_pair, runner_utils.stats_config, stats_dict)
5757
return
5858
use_input_communities = "" if not runner_utils.communities else "--input_communities=" + input_communities

stats_precision_recall_pair.py

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
def _config_str_to_dict(input_str):
77
# e.g.
8-
# input_str = "precision_recall_pair_threshold: 0.92,f_score_param: 0.5"
9-
# result_dict = {'precision_recall_pair_threshold': 0.92, 'f_score_param': 0.5}
8+
# input_str = "precision_recall_pair_thresholds: 0.86;0.88;0.90;0.92;0.94,f_score_param: 0.5"
9+
# result_dict = {'precision_recall_pair_thresholds': [0.86, 0.88, 0.90, 0.92, 0.94], 'f_score_param': 0.5}
1010

1111

1212
# Split the string into key-value pairs
@@ -24,12 +24,15 @@ def _config_str_to_dict(input_str):
2424
value = value.strip()
2525
# Convert the value to float if possible
2626
try:
27-
value = float(value)
27+
if key == "precision_recall_pair_thresholds":
28+
# Split the value by semicolons and convert each to float
29+
value = [float(v.strip()) for v in value.split(';')]
30+
else:
31+
value = float(value)
2832
except ValueError:
2933
pass # Keep the value as a string if it cannot be converted
3034
# Add the key-value pair to the dictionary
3135
result_dict[key] = value
32-
3336
return result_dict
3437

3538

@@ -70,7 +73,7 @@ def read_ground_truth_pairs(ground_truth_file):
7073
print(f"Ignoring invalid line: {line.strip()}")
7174
return pairs
7275

73-
def compute_precision_recall(node_to_clusters, pairs, threshold):
76+
def compute_precision_recall(node_to_clusters, pairs, thresholds, f_score_param):
7477
"""
7578
Computes precision and recall based on the clusters and ground truth pairs.
7679
Handles overlapping clusters where a node can belong to multiple clusters.
@@ -83,66 +86,74 @@ def compute_precision_recall(node_to_clusters, pairs, threshold):
8386
FP = 0 # False Positive
8487
FN = 0 # False Negative
8588

86-
for node1, node2, weight in pairs:
87-
# Determine if the pair is positive or negative
88-
is_positive = weight > threshold
89-
90-
# Determine if the nodes are in the same cluster (overlapping clusters)
91-
if node1 in node_to_clusters and node2 in node_to_clusters:
92-
clusters1 = node_to_clusters[node1]
93-
clusters2 = node_to_clusters[node2]
94-
in_same_cluster = bool(clusters1 & clusters2) # Check for non-empty intersection
95-
else:
96-
logging.warning("skipping nodes %s, %s", node1, node2)
97-
# Nodes not found in clusters; skip this pair
98-
continue
99-
100-
if is_positive:
101-
if in_same_cluster:
102-
TP += 1
103-
else:
104-
FN += 1
105-
else:
106-
if not in_same_cluster:
107-
TN += 1
108-
else:
109-
FP += 1
110-
111-
# Calculate precision and recall
112-
precision = TP / (TP + FP) if (TP + FP) > 0 else 0
113-
recall = TP / (TP + FN) if (TP + FN) > 0 else 0
114-
115-
return precision, recall, TP, FP, TN, FN
89+
precisions = {}
90+
recalls = {}
91+
f_scores = {}
92+
93+
for threshold in thresholds:
94+
for node1, node2, weight in pairs:
95+
# Determine if the pair is positive or negative
96+
is_positive = weight > threshold
97+
98+
# Determine if the nodes are in the same cluster (overlapping clusters)
99+
if node1 in node_to_clusters and node2 in node_to_clusters:
100+
clusters1 = node_to_clusters[node1]
101+
clusters2 = node_to_clusters[node2]
102+
in_same_cluster = bool(clusters1 & clusters2) # Check for non-empty intersection
103+
else:
104+
logging.warning("skipping nodes %s, %s", node1, node2)
105+
# Nodes not found in clusters; skip this pair
106+
continue
107+
108+
if is_positive:
109+
if in_same_cluster:
110+
TP += 1
111+
else:
112+
FN += 1
113+
else:
114+
if not in_same_cluster:
115+
TN += 1
116+
else:
117+
FP += 1
118+
119+
# Calculate precision and recall
120+
precision = TP / (TP + FP) if (TP + FP) > 0 else 0
121+
recall = TP / (TP + FN) if (TP + FN) > 0 else 0
122+
f_score = 0
123+
if precision !=0 and recall != 0:
124+
f_score = (1 + f_score_param * f_score_param) * precision * recall / ((f_score_param * f_score_param * precision) + recall)
125+
126+
precisions[threshold] = precision
127+
recalls[threshold] = recall
128+
f_scores[threshold] = f_score
129+
130+
return precisions, recalls, f_scores
116131

117132

118133
def compute_precision_recall_pair(in_clustering, input_communities, out_statistics, stats_config, stats_dict):
119134
"""
120135
Compute pair precision and recall, and record result into stats_dict
121136
"""
122137
stats_config = _config_str_to_dict(stats_config)
123-
precision_recall_pair_threshold = stats_config["precision_recall_pair_threshold"]
138+
precision_recall_pair_thresholds = stats_config["precision_recall_pair_thresholds"]
124139
f_score_param = stats_config.get("f_score_param", 1)
125140
print()
126141
print("clustering file", in_clustering)
127142
print("community file", input_communities)
128143
print("stats file", out_statistics)
129-
print("parameters, ", precision_recall_pair_threshold, f_score_param)
144+
print("parameters, ", precision_recall_pair_thresholds, f_score_param)
130145

131146
# Read clusters and ground truth pairs
132147
clusters = read_clusters(in_clustering)
133148
pairs = read_ground_truth_pairs(input_communities)
134149

135150
# Compute precision and recall
136-
precision, recall, TP, FP, TN, FN = compute_precision_recall(clusters, pairs, precision_recall_pair_threshold)
137-
138-
f_score = 0
139-
if precision !=0 and recall != 0:
140-
f_score = (1 + f_score_param * f_score_param) * precision * recall / ((f_score_param * f_score_param * precision) + recall)
151+
precisions, recalls, f_scores = compute_precision_recall(clusters, pairs, precision_recall_pair_thresholds, f_score_param)
141152

142-
stats_dict["fScore_mean"] = f_score
143-
stats_dict["communityPrecision_mean"] = precision
144-
stats_dict["communityRecall_mean"] = recall
145-
stats_dict["PrecisionRecallPairThreshold"] = precision_recall_pair_threshold
153+
stats_dict["fScore_mean"] = f_scores
154+
stats_dict["communityPrecision_mean"] = precisions
155+
stats_dict["communityRecall_mean"] = recalls
156+
stats_dict["PrecisionRecallPairThresholds"] = precision_recall_pair_thresholds
146157
stats_dict["fScoreParam"] = f_score_param
147158

148159
with open(out_statistics, 'w', encoding='utf-8') as f:

0 commit comments

Comments
 (0)