Skip to content

Commit f52fa6b

Browse files
committed
Minor changes
1 parent 491560e commit f52fa6b

File tree

1 file changed

+26
-30
lines changed
  • causallearn/search/ConstraintBased

1 file changed

+26
-30
lines changed

causallearn/search/ConstraintBased/FCI.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from queue import Queue
5-
from typing import List, Set, Tuple, Dict
5+
from typing import List, Set, Tuple, Dict, Generator
66
from numpy import ndarray
77

88
from causallearn.graph.Edge import Edge
@@ -44,7 +44,7 @@ def traverseCircle(node: Node, edge: Edge) -> Node | None:
4444
if edge.get_endpoint1() == Endpoint.CIRCLE and edge.get_endpoint2() == Endpoint.CIRCLE:
4545
return edge.get_node2()
4646
elif node == edge.get_node2():
47-
if edge.get_endpoint1() == Endpoint.CIRCLE or edge.get_endpoint2() == Endpoint.CIRCLE:
47+
if edge.get_endpoint1() == Endpoint.CIRCLE and edge.get_endpoint2() == Endpoint.CIRCLE:
4848
return edge.get_node1()
4949
return None
5050

@@ -82,7 +82,7 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: ##
8282

8383
return False
8484

85-
def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph) -> List[Node] | None:
85+
def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph) -> Generator[Node] | None:
8686
Q = Queue()
8787
V = set()
8888

@@ -102,7 +102,7 @@ def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph) -> List[Nod
102102
while not Q.empty():
103103
node_t, path = Q.get_nowait()
104104
if node_t == node_to and is_uncovered_path(path, G):
105-
return path
105+
yield path
106106

107107
for node_u in G.get_adjacent_nodes(node_t):
108108
edge = G.get_edge(node_t, node_u)
@@ -467,35 +467,31 @@ def ruleR5(graph: Graph, changeFlag: bool,
467467
for node_D in b_circle_adj_nodes_set:
468468
if graph.is_adjacent_to(node_A, node_D):
469469
continue
470-
path = GetUncoveredCirclePath(node_from=node_C, node_to=node_D, G=graph)
471-
if path is not None:
472-
# pdb.set_trace()
470+
paths = GetUncoveredCirclePath(node_from=node_C, node_to=node_D, G=graph)
471+
if paths is not None:
472+
# Mark the change if find at least one path
473473
changeFlag = True
474-
# orient A - C, D - B
475-
edge = graph.get_edge(node_A, path[0])
476-
graph.remove_edge(edge)
477-
graph.add_edge(Edge(node_A, path[0], Endpoint.TAIL, Endpoint.TAIL))
478-
479-
edge = graph.get_edge(node_B, path[-1])
480-
graph.remove_edge(edge)
481-
graph.add_edge(Edge(node_B, path[-1], Endpoint.TAIL, Endpoint.TAIL))
482-
if verbose:
483-
print("Orienting edge (Double tail): " + graph.get_edge(node_A, path[0]).__str__())
484-
print("Orienting edge (Double tail): " + graph.get_edge(node_B, path[-1]).__str__())
485-
486-
# orient everything on the path to both tails
487-
for i in range(len(path) - 1):
488-
edge = graph.get_edge(path[i], path[i + 1])
474+
for path in paths:
475+
# orient A - C, D - B
476+
edge = graph.get_edge(node_A, path[0])
489477
graph.remove_edge(edge)
490-
graph.add_edge(Edge(path[i], path[i + 1], Endpoint.TAIL, Endpoint.TAIL))
491-
if verbose:
492-
print("Orienting edge (Double tail): " + graph.get_edge(path[i], path[i + 1]).__str__())
493-
494-
# if not is_arrow_point_allowed(node_A, path[0], graph, bk):
495-
# break
496-
# if not is_arrow_point_allowed(node_B, path[-1], graph, bk):
497-
# break
478+
graph.add_edge(Edge(node_A, path[0], Endpoint.TAIL, Endpoint.TAIL))
498479

480+
edge = graph.get_edge(node_B, path[-1])
481+
graph.remove_edge(edge)
482+
graph.add_edge(Edge(node_B, path[-1], Endpoint.TAIL, Endpoint.TAIL))
483+
if verbose:
484+
print("Orienting edge (Double tail): " + graph.get_edge(node_A, path[0]).__str__())
485+
print("Orienting edge (Double tail): " + graph.get_edge(node_B, path[-1]).__str__())
486+
487+
# orient everything on the path to both tails
488+
for i in range(len(path) - 1):
489+
edge = graph.get_edge(path[i], path[i + 1])
490+
graph.remove_edge(edge)
491+
graph.add_edge(Edge(path[i], path[i + 1], Endpoint.TAIL, Endpoint.TAIL))
492+
if verbose:
493+
print("Orienting edge (Double tail): " + graph.get_edge(path[i], path[i + 1]).__str__())
494+
499495
continue
500496

501497
return changeFlag

0 commit comments

Comments
 (0)