1- from causallearn .graph .Endpoint import Endpoint
21from causallearn .graph .Graph import Graph
32
43
@@ -7,8 +6,6 @@ class SHD:
76 Compute the Structural Hamming Distance (SHD) between two graphs. In simple terms, this is the number of edge
87 insertions, deletions or flips in order to transform one graph to another graph.
98 """
10- __SHD = 0
11-
129 def __init__ (self , truth : Graph , est : Graph ):
1310 """
1411 Compute and store the Structural Hamming Distance (SHD) between two graphs.
@@ -20,30 +17,18 @@ def __init__(self, truth: Graph, est: Graph):
2017 est :
2118 Estimated graph.
2219 """
23- nodes = truth .get_nodes ()
24- nodes_name = [ node .get_name () for node in nodes ]
25- self . __SHD : int = 0
20+ truth_node_map = { node . get_name (): node_id for node , node_id in truth .node_map . items ()}
21+ est_node_map = { node .get_name (): node_id for node , node_id in est . node_map . items ()}
22+ assert set ( truth_node_map . keys ()) == set ( est_node_map . keys ()), "The two graphs have different sets of node names."
2623
27- # Assumes the list of nodes for the two graphs are the same.
28- for i in list (range (0 , len (nodes ))):
29- for j in list (range (i + 1 , len (nodes ))):
30- if truth .get_edge (truth .get_node (nodes_name [i ]), truth .get_node (nodes_name [j ])) and (
31- not est .get_edge (est .get_node (nodes_name [i ]), est .get_node (nodes_name [j ]))):
32- self .__SHD += 1
33- if (not truth .get_edge (truth .get_node (nodes_name [i ]), truth .get_node (nodes_name [j ]))) and est .get_edge (
34- est .get_node (nodes_name [i ]), est .get_node (nodes_name [j ])):
35- self .__SHD += 1
36-
37- for i in list (range (0 , len (nodes ))):
38- for j in list (range (0 , len (nodes ))):
39- if not truth .get_edge (truth .get_node (nodes_name [i ]), truth .get_node (nodes_name [j ])):
40- continue
41- if not est .get_edge (est .get_node (nodes_name [i ]), est .get_node (nodes_name [j ])):
42- continue
43- if truth .get_endpoint (truth .get_node (nodes_name [i ]),
44- truth .get_node (nodes_name [j ])) == Endpoint .ARROW and est .get_endpoint (
45- est .get_node (nodes_name [j ]), est .get_node (nodes_name [i ])) == Endpoint .ARROW :
46- self .__SHD += 1
24+ self .__SHD : int = 0
25+ for node_i_name , truth_node_i_id in truth_node_map .items ():
26+ for node_j_name , truth_node_j_id in truth_node_map .items ():
27+ if truth_node_j_id < truth_node_i_id : continue # we allow `==' to care about the possibly self-loops.
28+ est_node_i_id , est_node_j_id = est_node_map [node_i_name ], est_node_map [node_j_name ]
29+ truth_ij_edge_endpoints = (truth .graph [truth_node_i_id , truth_node_j_id ], truth .graph [truth_node_j_id , truth_node_i_id ])
30+ est_ij_edge_endpoints = (est .graph [est_node_i_id , est_node_j_id ], est .graph [est_node_j_id , est_node_i_id ])
31+ if truth_ij_edge_endpoints != est_ij_edge_endpoints : self .__SHD += 1
4732
4833 def get_shd (self ) -> int :
4934 return self .__SHD
0 commit comments