Skip to content

Commit 1ea1183

Browse files
committed
make SHD+=1 for the cases other than directed inversion
1 parent 446b9a2 commit 1ea1183

File tree

1 file changed

+11
-26
lines changed

1 file changed

+11
-26
lines changed

causallearn/graph/SHD.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from causallearn.graph.Endpoint import Endpoint
21
from causallearn.graph.Graph import Graph
32

43

@@ -7,8 +6,6 @@ class SHD:
76
Compute the Structural Hamming Distance (SHD) between two graphs. In simple terms, this is the number of edge
87
insertions, deletions or flips in order to transform one graph to another graph.
98
"""
10-
__SHD = 0
11-
129
def __init__(self, truth: Graph, est: Graph):
1310
"""
1411
Compute and store the Structural Hamming Distance (SHD) between two graphs.
@@ -20,30 +17,18 @@ def __init__(self, truth: Graph, est: Graph):
2017
est :
2118
Estimated graph.
2219
"""
23-
nodes = truth.get_nodes()
24-
nodes_name = [node.get_name() for node in nodes]
25-
self.__SHD: int = 0
20+
truth_node_map = {node.get_name(): node_id for node, node_id in truth.node_map.items()}
21+
est_node_map = {node.get_name(): node_id for node, node_id in est.node_map.items()}
22+
assert set(truth_node_map.keys()) == set(est_node_map.keys()), "The two graphs have different sets of node names."
2623

27-
# Assumes the list of nodes for the two graphs are the same.
28-
for i in list(range(0, len(nodes))):
29-
for j in list(range(i + 1, len(nodes))):
30-
if truth.get_edge(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])) and (
31-
not est.get_edge(est.get_node(nodes_name[i]), est.get_node(nodes_name[j]))):
32-
self.__SHD += 1
33-
if (not truth.get_edge(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j]))) and est.get_edge(
34-
est.get_node(nodes_name[i]), est.get_node(nodes_name[j])):
35-
self.__SHD += 1
36-
37-
for i in list(range(0, len(nodes))):
38-
for j in list(range(0, len(nodes))):
39-
if not truth.get_edge(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])):
40-
continue
41-
if not est.get_edge(est.get_node(nodes_name[i]), est.get_node(nodes_name[j])):
42-
continue
43-
if truth.get_endpoint(truth.get_node(nodes_name[i]),
44-
truth.get_node(nodes_name[j])) == Endpoint.ARROW and est.get_endpoint(
45-
est.get_node(nodes_name[j]), est.get_node(nodes_name[i])) == Endpoint.ARROW:
46-
self.__SHD += 1
24+
self.__SHD: int = 0
25+
for node_i_name, truth_node_i_id in truth_node_map.items():
26+
for node_j_name, truth_node_j_id in truth_node_map.items():
27+
if truth_node_j_id < truth_node_i_id: continue # we allow `==' to care about the possibly self-loops.
28+
est_node_i_id, est_node_j_id = est_node_map[node_i_name], est_node_map[node_j_name]
29+
truth_ij_edge_endpoints = (truth.graph[truth_node_i_id, truth_node_j_id], truth.graph[truth_node_j_id, truth_node_i_id])
30+
est_ij_edge_endpoints = (est.graph[est_node_i_id, est_node_j_id], est.graph[est_node_j_id, est_node_i_id])
31+
if truth_ij_edge_endpoints != est_ij_edge_endpoints: self.__SHD += 1
4732

4833
def get_shd(self) -> int:
4934
return self.__SHD

0 commit comments

Comments
 (0)