5
5
6
6
def _config_str_to_dict (input_str ):
7
7
# e.g.
8
- # input_str = "precision_recall_pair_threshold : 0.92 ,f_score_param: 0.5"
9
- # result_dict = {'precision_recall_pair_threshold ': 0.92 , 'f_score_param': 0.5}
8
+ # input_str = "precision_recall_pair_thresholds : 0.86;0.88;0.90;0.92;0.94 ,f_score_param: 0.5"
9
+ # result_dict = {'precision_recall_pair_thresholds ': [0.86, 0.88, 0.90, 0.92, 0.94] , 'f_score_param': 0.5}
10
10
11
11
12
12
# Split the string into key-value pairs
@@ -24,12 +24,15 @@ def _config_str_to_dict(input_str):
24
24
value = value .strip ()
25
25
# Convert the value to float if possible
26
26
try :
27
- value = float (value )
27
+ if key == "precision_recall_pair_thresholds" :
28
+ # Split the value by semicolons and convert each to float
29
+ value = [float (v .strip ()) for v in value .split (';' )]
30
+ else :
31
+ value = float (value )
28
32
except ValueError :
29
33
pass # Keep the value as a string if it cannot be converted
30
34
# Add the key-value pair to the dictionary
31
35
result_dict [key ] = value
32
-
33
36
return result_dict
34
37
35
38
@@ -70,7 +73,7 @@ def read_ground_truth_pairs(ground_truth_file):
70
73
print (f"Ignoring invalid line: { line .strip ()} " )
71
74
return pairs
72
75
73
- def compute_precision_recall (node_to_clusters , pairs , threshold ):
76
+ def compute_precision_recall (node_to_clusters , pairs , thresholds , f_score_param ):
74
77
"""
75
78
Computes precision and recall based on the clusters and ground truth pairs.
76
79
Handles overlapping clusters where a node can belong to multiple clusters.
@@ -83,66 +86,74 @@ def compute_precision_recall(node_to_clusters, pairs, threshold):
83
86
FP = 0 # False Positive
84
87
FN = 0 # False Negative
85
88
86
- for node1 , node2 , weight in pairs :
87
- # Determine if the pair is positive or negative
88
- is_positive = weight > threshold
89
-
90
- # Determine if the nodes are in the same cluster (overlapping clusters)
91
- if node1 in node_to_clusters and node2 in node_to_clusters :
92
- clusters1 = node_to_clusters [node1 ]
93
- clusters2 = node_to_clusters [node2 ]
94
- in_same_cluster = bool (clusters1 & clusters2 ) # Check for non-empty intersection
95
- else :
96
- logging .warning ("skipping nodes %s, %s" , node1 , node2 )
97
- # Nodes not found in clusters; skip this pair
98
- continue
99
-
100
- if is_positive :
101
- if in_same_cluster :
102
- TP += 1
103
- else :
104
- FN += 1
105
- else :
106
- if not in_same_cluster :
107
- TN += 1
108
- else :
109
- FP += 1
110
-
111
- # Calculate precision and recall
112
- precision = TP / (TP + FP ) if (TP + FP ) > 0 else 0
113
- recall = TP / (TP + FN ) if (TP + FN ) > 0 else 0
114
-
115
- return precision , recall , TP , FP , TN , FN
89
+ precisions = {}
90
+ recalls = {}
91
+ f_scores = {}
92
+
93
+ for threshold in thresholds :
94
+ for node1 , node2 , weight in pairs :
95
+ # Determine if the pair is positive or negative
96
+ is_positive = weight > threshold
97
+
98
+ # Determine if the nodes are in the same cluster (overlapping clusters)
99
+ if node1 in node_to_clusters and node2 in node_to_clusters :
100
+ clusters1 = node_to_clusters [node1 ]
101
+ clusters2 = node_to_clusters [node2 ]
102
+ in_same_cluster = bool (clusters1 & clusters2 ) # Check for non-empty intersection
103
+ else :
104
+ logging .warning ("skipping nodes %s, %s" , node1 , node2 )
105
+ # Nodes not found in clusters; skip this pair
106
+ continue
107
+
108
+ if is_positive :
109
+ if in_same_cluster :
110
+ TP += 1
111
+ else :
112
+ FN += 1
113
+ else :
114
+ if not in_same_cluster :
115
+ TN += 1
116
+ else :
117
+ FP += 1
118
+
119
+ # Calculate precision and recall
120
+ precision = TP / (TP + FP ) if (TP + FP ) > 0 else 0
121
+ recall = TP / (TP + FN ) if (TP + FN ) > 0 else 0
122
+ f_score = 0
123
+ if precision != 0 and recall != 0 :
124
+ f_score = (1 + f_score_param * f_score_param ) * precision * recall / ((f_score_param * f_score_param * precision ) + recall )
125
+
126
+ precisions [threshold ] = precision
127
+ recalls [threshold ] = recall
128
+ f_scores [threshold ] = f_score
129
+
130
+ return precisions , recalls , f_scores
116
131
117
132
118
133
def compute_precision_recall_pair (in_clustering , input_communities , out_statistics , stats_config , stats_dict ):
119
134
"""
120
135
Compute pair precision and recall, and record result into stats_dict
121
136
"""
122
137
stats_config = _config_str_to_dict (stats_config )
123
- precision_recall_pair_threshold = stats_config ["precision_recall_pair_threshold " ]
138
+ precision_recall_pair_thresholds = stats_config ["precision_recall_pair_thresholds " ]
124
139
f_score_param = stats_config .get ("f_score_param" , 1 )
125
140
print ()
126
141
print ("clustering file" , in_clustering )
127
142
print ("community file" , input_communities )
128
143
print ("stats file" , out_statistics )
129
- print ("parameters, " , precision_recall_pair_threshold , f_score_param )
144
+ print ("parameters, " , precision_recall_pair_thresholds , f_score_param )
130
145
131
146
# Read clusters and ground truth pairs
132
147
clusters = read_clusters (in_clustering )
133
148
pairs = read_ground_truth_pairs (input_communities )
134
149
135
150
# Compute precision and recall
136
- precision , recall , TP , FP , TN , FN = compute_precision_recall (clusters , pairs , precision_recall_pair_threshold )
137
-
138
- f_score = 0
139
- if precision != 0 and recall != 0 :
140
- f_score = (1 + f_score_param * f_score_param ) * precision * recall / ((f_score_param * f_score_param * precision ) + recall )
151
+ precisions , recalls , f_scores = compute_precision_recall (clusters , pairs , precision_recall_pair_thresholds , f_score_param )
141
152
142
- stats_dict ["fScore_mean" ] = f_score
143
- stats_dict ["communityPrecision_mean" ] = precision
144
- stats_dict ["communityRecall_mean" ] = recall
145
- stats_dict ["PrecisionRecallPairThreshold " ] = precision_recall_pair_threshold
153
+ stats_dict ["fScore_mean" ] = f_scores
154
+ stats_dict ["communityPrecision_mean" ] = precisions
155
+ stats_dict ["communityRecall_mean" ] = recalls
156
+ stats_dict ["PrecisionRecallPairThresholds " ] = precision_recall_pair_thresholds
146
157
stats_dict ["fScoreParam" ] = f_score_param
147
158
148
159
with open (out_statistics , 'w' , encoding = 'utf-8' ) as f :
0 commit comments