1
+
2
+ import logging
3
+ import sys
4
+ import json
5
+
6
def _config_str_to_dict(input_str):
    """Parse a comma-separated ``key: value`` string into a dictionary.

    Example:
        input_str = "precision_recall_pair_threshold: 0.92,f_score_param: 0.5"
        result    = {'precision_recall_pair_threshold': 0.92, 'f_score_param': 0.5}

    Args:
        input_str: Entries separated by ','; each entry is ``key: value``.
            Blank entries (empty input, trailing comma) are ignored.

    Returns:
        A dict mapping each whitespace-stripped key to its value. Values that
        ``float()`` accepts are stored as floats; everything else is kept as a
        whitespace-stripped string.

    Raises:
        ValueError: If a non-blank entry contains no ':' separator.
    """
    result_dict = {}

    for pair in input_str.split(','):
        # An empty config string or a trailing comma yields blank entries;
        # they carry no information, so skip them instead of failing.
        if not pair.strip():
            continue

        # Split on the first colon only so values may themselves contain ':'.
        key, separator, value = pair.partition(':')
        if not separator:
            raise ValueError(
                f"Malformed config entry {pair.strip()!r}: expected 'key: value'"
            )

        key = key.strip()
        value = value.strip()

        # Convert the value to float if possible, otherwise keep the string.
        try:
            value = float(value)
        except ValueError:
            pass

        result_dict[key] = value

    return result_dict
34
+
35
+
36
def read_clusters(cluster_file):
    """
    Load a clustering file into a mapping from node ID to a set of cluster IDs.

    Every line of the file lists the tab-separated node IDs of one cluster,
    and the zero-based line index is used as that cluster's ID. A node that
    appears on several lines maps to every corresponding cluster ID. Empty
    fields are ignored.
    """
    membership = {}
    with open(cluster_file, 'r') as handle:
        for index, raw_line in enumerate(handle):
            for field in raw_line.strip().split("\t"):
                name = field.strip()
                if not name:
                    continue
                # Create the node's set on first sight, then record the cluster.
                membership.setdefault(name, set()).add(index)
    return membership
51
+
52
def read_ground_truth_pairs(ground_truth_file):
    """
    Read weighted ground-truth pairs from a tab-separated file.

    A usable line has at least three tab-separated fields: node1, node2 and a
    numeric weight; any further fields are ignored. Blank lines are skipped
    silently. Lines with fewer than three fields, or with a weight that
    ``float()`` rejects, are reported with ``logging.warning`` and skipped.

    Args:
        ground_truth_file: Path of the file to read.

    Returns:
        A list of ``(node1, node2, weight)`` tuples in file order, where the
        node IDs are whitespace-stripped strings and ``weight`` is a float.
    """
    pairs = []
    with open(ground_truth_file, 'r') as f:
        for line in f:
            stripped = line.strip()
            # Blank lines are not malformed data; do not report them.
            if not stripped:
                continue

            parts = stripped.split('\t')
            if len(parts) < 3:
                logging.warning("Ignoring invalid line: %s", stripped)
                continue

            try:
                weight = float(parts[2].strip())
            except ValueError:
                logging.warning(
                    "Invalid weight '%s' in line: %s", parts[2], stripped
                )
                continue

            pairs.append((parts[0].strip(), parts[1].strip(), weight))
    return pairs
72
+
73
def compute_precision_recall(node_to_clusters, pairs, threshold):
    """
    Score a (possibly overlapping) clustering against weighted ground-truth pairs.

    A pair counts as positive when its weight is strictly greater than
    ``threshold``. Two nodes count as co-clustered when their cluster-ID sets
    share at least one element. Pairs with a node missing from
    ``node_to_clusters`` are logged and left out of every count.

    node_to_clusters: map from node id to a set of clusters
    pairs: list of (u,v,w) triplets
    threshold: weight above which a pair is treated as a true link

    Returns (precision, recall, TP, FP, TN, FN); precision and recall are 0
    when their denominator is zero.
    """
    # (ground-truth positive?, co-clustered?) -> confusion-matrix cell
    cell_for = {
        (True, True): "TP",
        (True, False): "FN",
        (False, False): "TN",
        (False, True): "FP",
    }
    tally = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}

    for node1, node2, weight in pairs:
        linked = bool(weight > threshold)

        if node1 not in node_to_clusters or node2 not in node_to_clusters:
            logging.warning("skipping nodes %s, %s", node1, node2)
            continue

        shared = node_to_clusters[node1] & node_to_clusters[node2]
        tally[cell_for[(linked, bool(shared))]] += 1

    tp, fp, tn, fn = tally["TP"], tally["FP"], tally["TN"], tally["FN"]

    predicted_positive = tp + fp
    actual_positive = tp + fn
    precision = tp / predicted_positive if predicted_positive else 0
    recall = tp / actual_positive if actual_positive else 0

    return precision, recall, tp, fp, tn, fn
116
+
117
+
118
def compute_precision_recall_pair(in_clustering, input_communities, out_statistics, stats_config, stats_dict):
    """
    Compute pair precision, recall and F-score, and record them in stats_dict.

    in_clustering: path of the clustering file, read with read_clusters
    input_communities: path of the ground-truth pair file, read with
        read_ground_truth_pairs
    out_statistics: path the updated stats_dict is written to as JSON
    stats_config: config string parsed with _config_str_to_dict; it must
        contain "precision_recall_pair_threshold" and may contain
        "f_score_param" (defaults to 1)
    stats_dict: dictionary updated in place with the results before it is
        dumped to out_statistics
    """
    config = _config_str_to_dict(stats_config)
    threshold = config["precision_recall_pair_threshold"]
    beta = config.get("f_score_param", 1)

    print()
    print("clustering file", in_clustering)
    print("community file", input_communities)
    print("stats file", out_statistics)
    print("parameters, ", threshold, beta)

    # Load both inputs, then score the clustering against the ground truth.
    node_to_clusters = read_clusters(in_clustering)
    truth_pairs = read_ground_truth_pairs(input_communities)
    precision, recall, _tp, _fp, _tn, _fn = compute_precision_recall(
        node_to_clusters, truth_pairs, threshold
    )

    # F-beta score; left at 0 when either precision or recall is 0.
    if precision == 0 or recall == 0:
        f_score = 0
    else:
        beta_squared = beta * beta
        f_score = (1 + beta_squared) * precision * recall / ((beta_squared * precision) + recall)

    stats_dict.update({
        "fScore_mean": f_score,
        "communityPrecision_mean": precision,
        "communityRecall_mean": recall,
        "PrecisionRecallPairThreshold": threshold,
        "fScoreParam": beta,
    })

    with open(out_statistics, 'w', encoding='utf-8') as out_file:
        json.dump(stats_dict, out_file, indent=4)
0 commit comments