Skip to content

Commit eee3e48

Browse files
committed
add pair precision recall stats
1 parent f065785 commit eee3e48

File tree

4 files changed

+164
-2
lines changed

4 files changed

+164
-2
lines changed

configs_experiments/ngrams/cluster_ngraphs_high_res_pcbs.config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@ ParallelAffinityClusterer:
4242

4343
ParHacClusterer:
4444
parhac_clusterer_config:
45-
weight_threshold: 0.5; 0.7320566343659267; 0.8564127056253706; 0.8833708760578991; 0.9052677145931002; 0.9230534741659427
45+
weight_threshold: 0.5; 0.7320566343659267; 0.82; 0.84; 0.86; 0.88; 0.9; 0.92; 0.94; 0.96; 0.98
4646
epsilon: 0.01;0.1;1
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Input communities: clusters.pair.cmty
2+
Deterministic: false
3+
4+
statistics_config:
5+
precision_recall_pair_threshold: 0.92
6+
f_score_param: 0.5

stats.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import json
1010
import pandas as pd
1111

12+
from stats_precision_recall_pair import compute_precision_recall_pair
13+
1214
def getRunTime(clusterer, out_prefix):
1315
cluster_time = -1
1416
out_filename = out_prefix + ".out"
@@ -42,13 +44,18 @@ def getRunTime(clusterer, out_prefix):
4244

4345
def runStats(out_prefix, graph, graph_idx, stats_dict):
4446
out_statistics = out_prefix + ".stats"
47+
out_statistics_pair = out_prefix + ".pair.stats"
4548
in_clustering = out_prefix + ".cluster"
4649
if not os.path.exists(in_clustering) or not os.path.getsize(in_clustering) > 0:
4750
# Either an error or a timeout happened
4851
runner_utils.appendToFile("ERROR", out_statistics)
4952
return
5053
use_input_graph = runner_utils.input_directory + graph
51-
use_input_communities = "" if not runner_utils.communities else "--input_communities=" + runner_utils.input_directory + runner_utils.communities[graph_idx]
54+
input_communities = runner_utils.input_directory + runner_utils.communities[graph_idx]
55+
if "precision_recall_pair_threshold" in runner_utils.stats_config:
56+
compute_precision_recall_pair(in_clustering, input_communities, out_statistics_pair, runner_utils.stats_config, stats_dict)
57+
return
58+
use_input_communities = "" if not runner_utils.communities else "--input_communities=" + input_communities
5259
ss = ("bazel run //clusterers:stats-in-memory_main -- "
5360
"--input_graph=" + use_input_graph + " "
5461
"--is_gbbs_format=" + runner_utils.gbbs_format + " "

stats_precision_recall_pair.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
2+
import logging
3+
import sys
4+
import json
5+
6+
def _config_str_to_dict(input_str):
7+
# e.g.
8+
# input_str = "precision_recall_pair_threshold: 0.92,f_score_param: 0.5"
9+
# result_dict = {'precision_recall_pair_threshold': 0.92, 'f_score_param': 0.5}
10+
11+
12+
# Split the string into key-value pairs
13+
pairs = input_str.split(',')
14+
15+
# Initialize an empty dictionary
16+
result_dict = {}
17+
18+
# Process each key-value pair
19+
for pair in pairs:
20+
# Split the pair by the colon
21+
key, value = pair.split(':')
22+
# Remove any leading/trailing whitespace
23+
key = key.strip()
24+
value = value.strip()
25+
# Convert the value to float if possible
26+
try:
27+
value = float(value)
28+
except ValueError:
29+
pass # Keep the value as a string if it cannot be converted
30+
# Add the key-value pair to the dictionary
31+
result_dict[key] = value
32+
33+
return result_dict
34+
35+
36+
def read_clusters(cluster_file):
    """
    Reads the clusters from a file and returns a dictionary mapping node IDs
    to a set of cluster IDs.

    Each line of the file is one cluster: a tab-separated list of node IDs.
    The cluster ID is the 0-based line number, so a node appearing on several
    lines maps to several cluster IDs (overlapping clusters).
    """
    node_to_clusters = {}
    with open(cluster_file, 'r') as fh:
        for cluster_id, row in enumerate(fh):
            for raw in row.strip().split("\t"):
                node = raw.strip()
                if not node:
                    continue
                node_to_clusters.setdefault(node, set()).add(cluster_id)
    return node_to_clusters
52+
def read_ground_truth_pairs(ground_truth_file):
    """
    Reads the ground truth pairs from a file and returns a list of tuples
    (node1, node2, weight).

    Expected format: tab-separated lines "node1<TAB>node2<TAB>weight".
    Lines with fewer than three fields, or a weight that is not a number,
    are skipped with a warning. Warnings go through logging (consistent
    with compute_precision_recall) instead of bare print.
    """
    pairs = []
    with open(ground_truth_file, 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 3:
                logging.warning("Ignoring invalid line: %s", line.strip())
                continue
            node1 = parts[0].strip()
            node2 = parts[1].strip()
            try:
                weight = float(parts[2].strip())
            except ValueError:
                logging.warning("Invalid weight '%s' in line: %s", parts[2], line.strip())
                continue
            pairs.append((node1, node2, weight))
    return pairs
73+
def compute_precision_recall(node_to_clusters, pairs, threshold):
    """
    Computes precision and recall based on the clusters and ground truth pairs.
    Handles overlapping clusters where a node can belong to multiple clusters.

    node_to_clusters: map from node id to a set of clusters
    pairs: list of (u, v, w) triplets
    threshold: a pair is treated as positive when its weight is strictly
        greater than this value

    Returns (precision, recall, TP, FP, TN, FN).
    """
    tp = fp = tn = fn = 0

    for u, v, weight in pairs:
        # Nodes missing from the clustering cannot be scored; skip the pair.
        if u not in node_to_clusters or v not in node_to_clusters:
            logging.warning("skipping nodes %s, %s", u, v)
            continue

        # With overlapping clusters, "same cluster" means the two nodes
        # share at least one cluster id.
        together = not node_to_clusters[u].isdisjoint(node_to_clusters[v])
        positive = weight > threshold

        if positive and together:
            tp += 1
        elif positive:
            fn += 1
        elif together:
            fp += 1
        else:
            tn += 1

    # Guard the divisions so an empty confusion row yields 0, not an error.
    precision = tp / (tp + fp) if tp + fp else 0
    recall = tp / (tp + fn) if tp + fn else 0
    return precision, recall, tp, fp, tn, fn
117+
118+
def compute_precision_recall_pair(in_clustering, input_communities, out_statistics, stats_config, stats_dict):
    """
    Compute pair precision and recall, and record result into stats_dict
    (the updated stats_dict is also dumped as JSON to out_statistics).
    """
    config = _config_str_to_dict(stats_config)
    threshold = config["precision_recall_pair_threshold"]
    # f_score_param is the beta of an F-beta score; defaults to 1 (plain F1).
    beta = config.get("f_score_param", 1)

    print()
    print("clustering file", in_clustering)
    print("community file", input_communities)
    print("stats file", out_statistics)
    print("parameters, ", threshold, beta)

    # Read clusters and ground truth pairs, then score the clustering.
    precision, recall, TP, FP, TN, FN = compute_precision_recall(
        read_clusters(in_clustering),
        read_ground_truth_pairs(input_communities),
        threshold,
    )

    # F-beta score; defined as 0 when either precision or recall is 0.
    f_score = 0
    if precision != 0 and recall != 0:
        beta_sq = beta * beta
        f_score = (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)

    stats_dict["fScore_mean"] = f_score
    stats_dict["communityPrecision_mean"] = precision
    stats_dict["communityRecall_mean"] = recall
    stats_dict["PrecisionRecallPairThreshold"] = threshold
    stats_dict["fScoreParam"] = beta

    with open(out_statistics, 'w', encoding='utf-8') as f:
        json.dump(stats_dict, f, indent=4)

0 commit comments

Comments
 (0)