Skip to content

Commit 5764c6f

Browse files
committed
Bug fix in finding undirected circle paths, and allow multiple ucp
1 parent f52fa6b commit 5764c6f

File tree

1 file changed

+38
-31
lines changed
  • causallearn/search/ConstraintBased

1 file changed

+38
-31
lines changed

causallearn/search/ConstraintBased/FCI.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,19 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: ##
8282

8383
return False
8484

85-
def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph) -> Generator[Node] | None:
85+
def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph, exclude_node: List[Node]) -> Generator[Node] | None:
8686
Q = Queue()
8787
V = set()
8888

8989
path = [node_from]
9090

9191
for node_u in G.get_adjacent_nodes(node_from):
92+
if node_u in exclude_node:
93+
continue
9294
edge = G.get_edge(node_from, node_u)
9395
node_c = traverseCircle(node_from, edge)
9496

95-
if node_c is None:
97+
if node_c is None or node_c in exclude_node:
9698
continue
9799

98100
if not V.__contains__(node_c):
@@ -108,14 +110,13 @@ def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph) -> Generato
108110
edge = G.get_edge(node_t, node_u)
109111
node_c = traverseCircle(node_t, edge)
110112

111-
if node_c is None:
113+
if node_c is None or node_c in exclude_node:
112114
continue
113115

114116
if not V.__contains__(node_c):
115117
V.add(node_c)
116118
Q.put((node_c, path + [node_c]))
117119

118-
return None
119120

120121

121122
def existOnePathWithPossibleParents(previous, node_w: Node, node_x: Node, node_b: Node, graph: Graph) -> bool:
@@ -440,10 +441,32 @@ def ruleR5(graph: Graph, changeFlag: bool,
440441
double tail.
441442
"""
442443
nodes = graph.get_nodes()
444+
def orient_on_path_helper(path, node_A, node_B):
445+
# orient A - C, D - B
446+
edge = graph.get_edge(node_A, path[0])
447+
graph.remove_edge(edge)
448+
graph.add_edge(Edge(node_A, path[0], Endpoint.TAIL, Endpoint.TAIL))
449+
450+
edge = graph.get_edge(node_B, path[-1])
451+
graph.remove_edge(edge)
452+
graph.add_edge(Edge(node_B, path[-1], Endpoint.TAIL, Endpoint.TAIL))
453+
if verbose:
454+
print("Orienting edge A - C (Double tail): " + graph.get_edge(node_A, path[0]).__str__())
455+
print("Orienting edge B - D (Double tail): " + graph.get_edge(node_B, path[-1]).__str__())
456+
457+
# orient everything on the path to both tails
458+
for i in range(len(path) - 1):
459+
edge = graph.get_edge(path[i], path[i + 1])
460+
graph.remove_edge(edge)
461+
graph.add_edge(Edge(path[i], path[i + 1], Endpoint.TAIL, Endpoint.TAIL))
462+
if verbose:
463+
print("Orienting edge (Double tail): " + graph.get_edge(path[i], path[i + 1]).__str__())
464+
443465
for node_B in nodes:
444466
intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE)
445467

446468
for node_A in intoBCircles:
469+
found_paths_between_AB = []
447470
if graph.get_endpoint(node_B, node_A) != Endpoint.CIRCLE:
448471
continue
449472
else:
@@ -467,33 +490,17 @@ def ruleR5(graph: Graph, changeFlag: bool,
467490
for node_D in b_circle_adj_nodes_set:
468491
if graph.is_adjacent_to(node_A, node_D):
469492
continue
470-
paths = GetUncoveredCirclePath(node_from=node_C, node_to=node_D, G=graph)
471-
if paths is not None:
472-
# Mark the change if find at least one path
473-
changeFlag = True
474-
for path in paths:
475-
# orient A - C, D - B
476-
edge = graph.get_edge(node_A, path[0])
477-
graph.remove_edge(edge)
478-
graph.add_edge(Edge(node_A, path[0], Endpoint.TAIL, Endpoint.TAIL))
479-
480-
edge = graph.get_edge(node_B, path[-1])
481-
graph.remove_edge(edge)
482-
graph.add_edge(Edge(node_B, path[-1], Endpoint.TAIL, Endpoint.TAIL))
483-
if verbose:
484-
print("Orienting edge (Double tail): " + graph.get_edge(node_A, path[0]).__str__())
485-
print("Orienting edge (Double tail): " + graph.get_edge(node_B, path[-1]).__str__())
486-
487-
# orient everything on the path to both tails
488-
for i in range(len(path) - 1):
489-
edge = graph.get_edge(path[i], path[i + 1])
490-
graph.remove_edge(edge)
491-
graph.add_edge(Edge(path[i], path[i + 1], Endpoint.TAIL, Endpoint.TAIL))
492-
if verbose:
493-
print("Orienting edge (Double tail): " + graph.get_edge(path[i], path[i + 1]).__str__())
494-
495-
continue
496-
493+
paths = GetUncoveredCirclePath(node_from=node_C, node_to=node_D, G=graph, exclude_node=[node_A, node_B]) # get the uncovered circle path between C and D, excluding A and B
494+
found_paths_between_AB.append(paths)
495+
496+
# Orient the uncovered circle path between A and B
497+
for paths in found_paths_between_AB:
498+
for path in paths:
499+
changeFlag = True
500+
if verbose:
501+
print("Find uncovered circle path between A and B: " + graph.get_edge(node_A, node_B).__str__())
502+
orient_on_path_helper(path, node_A, node_B)
503+
497504
return changeFlag
498505

499506
def ruleR6(graph: Graph, changeFlag: bool,

0 commit comments

Comments
 (0)