Skip to content

Commit 56a410c

Browse files
authored
Merge pull request #68 from zhi-yi-huang/main
Refactored unit tests for FCI
2 parents e5e76e6 + e83d6e4 commit 56a410c

File tree

36 files changed

+1440
-772
lines changed

36 files changed

+1440
-772
lines changed

causallearn/graph/GeneralGraph.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def adjust_dpath(self, i: int, j: int):
6464
self.dpath = dpath
6565

6666
def reconstitute_dpath(self, edges: List[Edge]):
67+
self.dpath = np.zeros((self.num_vars, self.num_vars), np.dtype(int))
6768
for i in range(self.num_vars):
6869
self.adjust_dpath(i, i)
6970

@@ -73,7 +74,11 @@ def reconstitute_dpath(self, edges: List[Edge]):
7374
node2 = edge.get_node2()
7475
i = self.node_map[node1]
7576
j = self.node_map[node2]
76-
self.adjust_dpath(i, j)
77+
if self.is_parent_of(node1, node2):
78+
self.adjust_dpath(i, j)
79+
elif self.is_parent_of(node2, node1):
80+
self.adjust_dpath(j, i)
81+
7782

7883
def collect_ancestors(self, node: Node, ancestors: List[Node]):
7984
if node in ancestors:
@@ -503,13 +508,13 @@ def is_ancestor_of(self, node1: Node, node2: Node) -> bool:
503508
def is_child_of(self, node1: Node, node2: Node) -> bool:
504509
i = self.node_map[node1]
505510
j = self.node_map[node2]
506-
return self.graph[i, j] == 1 or self.graph[i, j] == Endpoint.ARROW_AND_ARROW.value
511+
return self.graph[i, j] == Endpoint.TAIL.value or self.graph[i, j] == Endpoint.ARROW_AND_ARROW.value
507512

508513
# Returns true iff node1 is a parent of node2.
509514
def is_parent_of(self, node1: Node, node2: Node) -> bool:
510515
i = self.node_map[node1]
511516
j = self.node_map[node2]
512-
return self.graph[j, i] == 1 or self.graph[j, i] == Endpoint.ARROW_AND_ARROW.value
517+
return self.graph[j, i] == Endpoint.ARROW.value and self.graph[i, j] == Endpoint.TAIL.value
513518

514519
# Returns true iff node1 is a proper ancestor of node2.
515520
def is_proper_ancestor_of(self, node1: Node, node2: Node) -> bool:
@@ -521,9 +526,7 @@ def is_proper_descendant_of(self, node1: Node, node2: Node) -> bool:
521526

522527
# Returns true iff node1 is a descendant of node2.
523528
def is_descendant_of(self, node1: Node, node2: Node) -> bool:
524-
i = self.node_map[node1]
525-
j = self.node_map[node2]
526-
return self.dpath[i, j] == 1
529+
return self.is_ancestor_of(node2, node1)
527530

528531
# Returns the edge connecting node1 and node2, provided a unique such edge exists.
529532
def get_edge(self, node1: Node, node2: Node) -> Edge | None:
@@ -763,6 +766,8 @@ def remove_edge(self, edge: Edge):
763766
end1 = edge.get_numerical_endpoint1()
764767
end2 = edge.get_numerical_endpoint2()
765768

769+
is_fully_directed = self.is_parent_of(node1, node2) or self.is_parent_of(node2, node1)
770+
766771
if out_of == Endpoint.TAIL_AND_ARROW.value and in_to == Endpoint.TAIL_AND_ARROW.value:
767772
if end1 == Endpoint.ARROW.value:
768773
self.graph[j, i] = -1
@@ -794,7 +799,8 @@ def remove_edge(self, edge: Edge):
794799
self.graph[j, i] = 0
795800
self.graph[i, j] = 0
796801

797-
self.reconstitute_dpath(self.get_graph_edges())
802+
if is_fully_directed:
803+
self.reconstitute_dpath(self.get_graph_edges())
798804

799805
# Removes the edge connecting the given two nodes, provided there is exactly one such edge.
800806
def remove_connecting_edge(self, node1: Node, node2: Node):

0 commit comments

Comments
 (0)