Skip to content

Commit 285c6c2

Browse files
committed
Add R5, 6, 7 for the Augmented FCI
1 parent 5d788b9 commit 285c6c2

File tree

1 file changed

+189
-1
lines changed
  • causallearn/search/ConstraintBased

1 file changed

+189
-1
lines changed

causallearn/search/ConstraintBased/FCI.py

Lines changed: 189 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,19 @@
1717
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
1818
from itertools import combinations
1919

20+
def is_uncovered_path(nodes: List[Node], G: Graph) -> bool:
21+
"""
22+
Determines whether the given path is an uncovered path in this graph.
23+
24+
A path is an uncovered path if no two nonconsecutive nodes (Vi-1 and Vi+1) in the path are
25+
adjacent.
26+
"""
27+
for i in range(len(nodes) - 2):
28+
if G.is_adjacent_to(nodes[i], nodes[i + 2]):
29+
return False
30+
return True
31+
32+
2033
def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
2134
if node == edge.get_node1():
2235
if edge.get_endpoint1() == Endpoint.TAIL or edge.get_endpoint1() == Endpoint.CIRCLE:
@@ -26,8 +39,17 @@ def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
2639
return edge.get_node1()
2740
return None
2841

42+
def traverseCircle(node: Node, edge: Edge) -> Node | None:
43+
if node == edge.get_node1():
44+
if edge.get_endpoint1() == Endpoint.CIRCLE and edge.get_endpoint2() == Endpoint.CIRCLE:
45+
return edge.get_node2()
46+
elif node == edge.get_node2():
47+
if edge.get_endpoint1() == Endpoint.CIRCLE or edge.get_endpoint2() == Endpoint.CIRCLE:
48+
return edge.get_node1()
49+
return None
50+
2951

30-
def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool:
52+
def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: ## TODO: Now it does not detect whether the path is an uncovered path
3153
Q = Queue()
3254
V = set()
3355

@@ -60,6 +82,41 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool:
6082

6183
return False
6284

85+
def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph) -> List[Node] | None:
86+
Q = Queue()
87+
V = set()
88+
89+
path = [node_from]
90+
91+
for node_u in G.get_adjacent_nodes(node_from):
92+
edge = G.get_edge(node_from, node_u)
93+
node_c = traverseCircle(node_from, edge)
94+
95+
if node_c is None:
96+
continue
97+
98+
if not V.__contains__(node_c):
99+
V.add(node_c)
100+
Q.put((node_c, path + [node_c]))
101+
102+
while not Q.empty():
103+
node_t, path = Q.get_nowait()
104+
if node_t == node_to and is_uncovered_path(path, G):
105+
return path
106+
107+
for node_u in G.get_adjacent_nodes(node_t):
108+
edge = G.get_edge(node_t, node_u)
109+
node_c = traverseCircle(node_t, edge)
110+
111+
if node_c is None:
112+
continue
113+
114+
if not V.__contains__(node_c):
115+
V.add(node_c)
116+
Q.put((node_c, path + [node_c]))
117+
118+
return None
119+
63120

64121
def existOnePathWithPossibleParents(previous, node_w: Node, node_x: Node, node_b: Node, graph: Graph) -> bool:
65122
if node_w == node_x:
@@ -371,6 +428,126 @@ def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: Backgrou
371428
changeFlag = True
372429
return changeFlag
373430

431+
def ruleR5(graph: Graph, changeFlag: bool,
432+
verbose: bool = False) -> bool:
433+
"""
434+
Rule R5 of the FCI algorithm.
435+
by Jiji Zhang, 2008, "On the completeness of orientation rules for causal discovery in the presence of latent confounders and selection bias"]
436+
437+
This function orients any edge that is part of an uncovered circle path between two nodes A and B,
438+
if such a path exists. The path must start and end with a circle edge and must be uncovered, i.e. the
439+
nodes on the path must not be adjacent to A or B. The orientation of the edges on the path is set to
440+
double tail.
441+
"""
442+
nodes = graph.get_nodes()
443+
for node_B in nodes:
444+
intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE)
445+
446+
for node_A in intoBCircles:
447+
if graph.get_endpoint(node_B, node_A) != Endpoint.CIRCLE:
448+
continue
449+
else:
450+
# Check if there is an uncovered circle path between A and B (A o-o C .. D o-o B)
451+
# s.t. A is not adjacent to D and B is not adjacent to C
452+
a_node_idx = graph.node_map[node_A]
453+
b_node_idx = graph.node_map[node_B]
454+
a_adj_nodes = graph.get_adjacent_nodes(node_A)
455+
b_adj_nodes = graph.get_adjacent_nodes(node_B)
456+
457+
# get the adjacent nodes with circle edges of A and B
458+
a_circle_adj_nodes_set = [node for node in a_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= b_node_idx
459+
and graph.get_endpoint(node, node_A) == Endpoint.CIRCLE and graph.get_endpoint(node_A, node) == Endpoint.CIRCLE]
460+
b_circle_adj_nodes_set = [node for node in b_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= b_node_idx
461+
and graph.get_endpoint(node, node_B) == Endpoint.CIRCLE and graph.get_endpoint(node_B, node) == Endpoint.CIRCLE]
462+
463+
# get the adjacent nodes with circle edges of A and B that is non adjacent to B and A, respectively
464+
for node_C in a_circle_adj_nodes_set:
465+
if graph.is_adjacent_to(node_B, node_C):
466+
continue
467+
for node_D in b_circle_adj_nodes_set:
468+
if graph.is_adjacent_to(node_A, node_D):
469+
continue
470+
path = GetUncoveredCirclePath(node_from=node_C, node_to=node_D, G=graph)
471+
if path is not None:
472+
# pdb.set_trace()
473+
changeFlag = True
474+
# orient A - C, D - B
475+
edge = graph.get_edge(node_A, path[0])
476+
graph.remove_edge(edge)
477+
graph.add_edge(Edge(node_A, path[0], Endpoint.TAIL, Endpoint.TAIL))
478+
479+
edge = graph.get_edge(node_B, path[-1])
480+
graph.remove_edge(edge)
481+
graph.add_edge(Edge(node_B, path[-1], Endpoint.TAIL, Endpoint.TAIL))
482+
if verbose:
483+
print("Orienting edge (Double tail): " + graph.get_edge(node_A, path[0]).__str__())
484+
print("Orienting edge (Double tail): " + graph.get_edge(node_B, path[-1]).__str__())
485+
486+
# orient everything on the path to both tails
487+
for i in range(len(path) - 1):
488+
edge = graph.get_edge(path[i], path[i + 1])
489+
graph.remove_edge(edge)
490+
graph.add_edge(Edge(path[i], path[i + 1], Endpoint.TAIL, Endpoint.TAIL))
491+
if verbose:
492+
print("Orienting edge (Double tail): " + graph.get_edge(path[i], path[i + 1]).__str__())
493+
494+
# if not is_arrow_point_allowed(node_A, path[0], graph, bk):
495+
# break
496+
# if not is_arrow_point_allowed(node_B, path[-1], graph, bk):
497+
# break
498+
499+
continue
500+
501+
return changeFlag
502+
503+
def ruleR6(graph: Graph, changeFlag: bool,
504+
verbose: bool = False) -> bool:
505+
nodes = graph.get_nodes()
506+
507+
for node_B in nodes:
508+
# Find A - B
509+
intoBTails = graph.get_nodes_into(node_B, Endpoint.TAIL)
510+
exist = False
511+
for node_A in intoBTails:
512+
if graph.get_endpoint(node_B, node_A) == Endpoint.TAIL:
513+
exist = True
514+
if not exist:
515+
continue
516+
# Find B o-*C
517+
intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE)
518+
for node_C in intoBCircles:
519+
changeFlag = True
520+
edge = graph.get_edge(node_B, node_C)
521+
graph.remove_edge(edge)
522+
graph.add_edge(Edge(node_B, node_C, Endpoint.TAIL, edge.get_proximal_endpoint(node_C)))
523+
if verbose:
524+
print("Orienting edge by rule 6): " + graph.get_edge(node_B, node_C).__str__())
525+
526+
return changeFlag
527+
528+
529+
def ruleR7(graph: Graph, changeFlag: bool,
530+
verbose: bool = False) -> bool:
531+
nodes = graph.get_nodes()
532+
533+
for node_B in nodes:
534+
# Find A -o B
535+
intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE)
536+
node_A_list = [node for node in intoBCircles if graph.get_endpoint(node_B, node) == Endpoint.TAIL]
537+
538+
# Find B o-*C
539+
for node_C in intoBCircles:
540+
# pdb.set_trace()
541+
for node_A in node_A_list:
542+
# pdb.set_trace()
543+
if not graph.is_adjacent_to(node_A, node_C):
544+
changeFlag = True
545+
edge = graph.get_edge(node_B, node_C)
546+
graph.remove_edge(edge)
547+
graph.add_edge(Edge(node_B, node_C, Endpoint.TAIL, edge.get_proximal_endpoint(node_C)))
548+
if verbose:
549+
print("Orienting edge by rule 7): " + graph.get_edge(node_B, node_C).__str__())
550+
return changeFlag
374551

375552
def getPath(node_c: Node, previous) -> List[Node]:
376553
l = []
@@ -896,6 +1073,7 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
8961073
graph, sep_sets, test_results = fas(dataset, nodes, independence_test_method=independence_test_method, alpha=alpha,
8971074
knowledge=background_knowledge, depth=depth, verbose=verbose, show_progress=show_progress)
8981075

1076+
# pdb.set_trace()
8991077
reorientAllWith(graph, Endpoint.CIRCLE)
9001078

9011079
rule0(graph, nodes, sep_sets, background_knowledge, verbose)
@@ -926,8 +1104,18 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
9261104
if verbose:
9271105
print("Epoch")
9281106

1107+
# rule 5
1108+
change_flag = ruleR5(graph, change_flag, verbose)
1109+
1110+
# rule 6
1111+
change_flag = ruleR6(graph, change_flag, verbose)
1112+
1113+
# rule 7
1114+
change_flag = ruleR7(graph, change_flag, verbose)
1115+
9291116
# rule 8
9301117
change_flag = rule8(graph,nodes)
1118+
9311119
# rule 9
9321120
change_flag = rule9(graph, nodes)
9331121
# rule 10

0 commit comments

Comments
 (0)