1717from causallearn .utils .PCUtils .BackgroundKnowledge import BackgroundKnowledge
1818from itertools import combinations
1919
20+ def is_uncovered_path (nodes : List [Node ], G : Graph ) -> bool :
21+ """
22+ Determines whether the given path is an uncovered path in this graph.
23+
24+ A path is an uncovered path if no two nonconsecutive nodes (Vi-1 and Vi+1) in the path are
25+ adjacent.
26+ """
27+ for i in range (len (nodes ) - 2 ):
28+ if G .is_adjacent_to (nodes [i ], nodes [i + 2 ]):
29+ return False
30+ return True
31+
32+
2033def traverseSemiDirected (node : Node , edge : Edge ) -> Node | None :
2134 if node == edge .get_node1 ():
2235 if edge .get_endpoint1 () == Endpoint .TAIL or edge .get_endpoint1 () == Endpoint .CIRCLE :
@@ -26,8 +39,17 @@ def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
2639 return edge .get_node1 ()
2740 return None
2841
42+ def traverseCircle (node : Node , edge : Edge ) -> Node | None :
43+ if node == edge .get_node1 ():
44+ if edge .get_endpoint1 () == Endpoint .CIRCLE and edge .get_endpoint2 () == Endpoint .CIRCLE :
45+ return edge .get_node2 ()
46+ elif node == edge .get_node2 ():
47+ if edge .get_endpoint1 () == Endpoint .CIRCLE or edge .get_endpoint2 () == Endpoint .CIRCLE :
48+ return edge .get_node1 ()
49+ return None
50+
2951
30- def existsSemiDirectedPath (node_from : Node , node_to : Node , G : Graph ) -> bool :
52+ def existsSemiDirectedPath (node_from : Node , node_to : Node , G : Graph ) -> bool : ## TODO: Now it does not detect whether the path is an uncovered path
3153 Q = Queue ()
3254 V = set ()
3355
@@ -60,6 +82,41 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool:
6082
6183 return False
6284
85+ def GetUncoveredCirclePath (node_from : Node , node_to : Node , G : Graph ) -> List [Node ] | None :
86+ Q = Queue ()
87+ V = set ()
88+
89+ path = [node_from ]
90+
91+ for node_u in G .get_adjacent_nodes (node_from ):
92+ edge = G .get_edge (node_from , node_u )
93+ node_c = traverseCircle (node_from , edge )
94+
95+ if node_c is None :
96+ continue
97+
98+ if not V .__contains__ (node_c ):
99+ V .add (node_c )
100+ Q .put ((node_c , path + [node_c ]))
101+
102+ while not Q .empty ():
103+ node_t , path = Q .get_nowait ()
104+ if node_t == node_to and is_uncovered_path (path , G ):
105+ return path
106+
107+ for node_u in G .get_adjacent_nodes (node_t ):
108+ edge = G .get_edge (node_t , node_u )
109+ node_c = traverseCircle (node_t , edge )
110+
111+ if node_c is None :
112+ continue
113+
114+ if not V .__contains__ (node_c ):
115+ V .add (node_c )
116+ Q .put ((node_c , path + [node_c ]))
117+
118+ return None
119+
63120
64121def existOnePathWithPossibleParents (previous , node_w : Node , node_x : Node , node_b : Node , graph : Graph ) -> bool :
65122 if node_w == node_x :
@@ -371,6 +428,126 @@ def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: Backgrou
371428 changeFlag = True
372429 return changeFlag
373430
431+ def ruleR5 (graph : Graph , changeFlag : bool ,
432+ verbose : bool = False ) -> bool :
433+ """
434+ Rule R5 of the FCI algorithm.
435+ by Jiji Zhang, 2008, "On the completeness of orientation rules for causal discovery in the presence of latent confounders and selection bias"]
436+
437+ This function orients any edge that is part of an uncovered circle path between two nodes A and B,
438+ if such a path exists. The path must start and end with a circle edge and must be uncovered, i.e. the
439+ nodes on the path must not be adjacent to A or B. The orientation of the edges on the path is set to
440+ double tail.
441+ """
442+ nodes = graph .get_nodes ()
443+ for node_B in nodes :
444+ intoBCircles = graph .get_nodes_into (node_B , Endpoint .CIRCLE )
445+
446+ for node_A in intoBCircles :
447+ if graph .get_endpoint (node_B , node_A ) != Endpoint .CIRCLE :
448+ continue
449+ else :
450+ # Check if there is an uncovered circle path between A and B (A o-o C .. D o-o B)
451+ # s.t. A is not adjacent to D and B is not adjacent to C
452+ a_node_idx = graph .node_map [node_A ]
453+ b_node_idx = graph .node_map [node_B ]
454+ a_adj_nodes = graph .get_adjacent_nodes (node_A )
455+ b_adj_nodes = graph .get_adjacent_nodes (node_B )
456+
457+ # get the adjacent nodes with circle edges of A and B
458+ a_circle_adj_nodes_set = [node for node in a_adj_nodes if graph .node_map [node ] != a_node_idx and graph .node_map [node ]!= b_node_idx
459+ and graph .get_endpoint (node , node_A ) == Endpoint .CIRCLE and graph .get_endpoint (node_A , node ) == Endpoint .CIRCLE ]
460+ b_circle_adj_nodes_set = [node for node in b_adj_nodes if graph .node_map [node ] != a_node_idx and graph .node_map [node ]!= b_node_idx
461+ and graph .get_endpoint (node , node_B ) == Endpoint .CIRCLE and graph .get_endpoint (node_B , node ) == Endpoint .CIRCLE ]
462+
463+ # get the adjacent nodes with circle edges of A and B that is non adjacent to B and A, respectively
464+ for node_C in a_circle_adj_nodes_set :
465+ if graph .is_adjacent_to (node_B , node_C ):
466+ continue
467+ for node_D in b_circle_adj_nodes_set :
468+ if graph .is_adjacent_to (node_A , node_D ):
469+ continue
470+ path = GetUncoveredCirclePath (node_from = node_C , node_to = node_D , G = graph )
471+ if path is not None :
472+ # pdb.set_trace()
473+ changeFlag = True
474+ # orient A - C, D - B
475+ edge = graph .get_edge (node_A , path [0 ])
476+ graph .remove_edge (edge )
477+ graph .add_edge (Edge (node_A , path [0 ], Endpoint .TAIL , Endpoint .TAIL ))
478+
479+ edge = graph .get_edge (node_B , path [- 1 ])
480+ graph .remove_edge (edge )
481+ graph .add_edge (Edge (node_B , path [- 1 ], Endpoint .TAIL , Endpoint .TAIL ))
482+ if verbose :
483+ print ("Orienting edge (Double tail): " + graph .get_edge (node_A , path [0 ]).__str__ ())
484+ print ("Orienting edge (Double tail): " + graph .get_edge (node_B , path [- 1 ]).__str__ ())
485+
486+ # orient everything on the path to both tails
487+ for i in range (len (path ) - 1 ):
488+ edge = graph .get_edge (path [i ], path [i + 1 ])
489+ graph .remove_edge (edge )
490+ graph .add_edge (Edge (path [i ], path [i + 1 ], Endpoint .TAIL , Endpoint .TAIL ))
491+ if verbose :
492+ print ("Orienting edge (Double tail): " + graph .get_edge (path [i ], path [i + 1 ]).__str__ ())
493+
494+ # if not is_arrow_point_allowed(node_A, path[0], graph, bk):
495+ # break
496+ # if not is_arrow_point_allowed(node_B, path[-1], graph, bk):
497+ # break
498+
499+ continue
500+
501+ return changeFlag
502+
503+ def ruleR6 (graph : Graph , changeFlag : bool ,
504+ verbose : bool = False ) -> bool :
505+ nodes = graph .get_nodes ()
506+
507+ for node_B in nodes :
508+ # Find A - B
509+ intoBTails = graph .get_nodes_into (node_B , Endpoint .TAIL )
510+ exist = False
511+ for node_A in intoBTails :
512+ if graph .get_endpoint (node_B , node_A ) == Endpoint .TAIL :
513+ exist = True
514+ if not exist :
515+ continue
516+ # Find B o-*C
517+ intoBCircles = graph .get_nodes_into (node_B , Endpoint .CIRCLE )
518+ for node_C in intoBCircles :
519+ changeFlag = True
520+ edge = graph .get_edge (node_B , node_C )
521+ graph .remove_edge (edge )
522+ graph .add_edge (Edge (node_B , node_C , Endpoint .TAIL , edge .get_proximal_endpoint (node_C )))
523+ if verbose :
524+ print ("Orienting edge by rule 6): " + graph .get_edge (node_B , node_C ).__str__ ())
525+
526+ return changeFlag
527+
528+
529+ def ruleR7 (graph : Graph , changeFlag : bool ,
530+ verbose : bool = False ) -> bool :
531+ nodes = graph .get_nodes ()
532+
533+ for node_B in nodes :
534+ # Find A -o B
535+ intoBCircles = graph .get_nodes_into (node_B , Endpoint .CIRCLE )
536+ node_A_list = [node for node in intoBCircles if graph .get_endpoint (node_B , node ) == Endpoint .TAIL ]
537+
538+ # Find B o-*C
539+ for node_C in intoBCircles :
540+ # pdb.set_trace()
541+ for node_A in node_A_list :
542+ # pdb.set_trace()
543+ if not graph .is_adjacent_to (node_A , node_C ):
544+ changeFlag = True
545+ edge = graph .get_edge (node_B , node_C )
546+ graph .remove_edge (edge )
547+ graph .add_edge (Edge (node_B , node_C , Endpoint .TAIL , edge .get_proximal_endpoint (node_C )))
548+ if verbose :
549+ print ("Orienting edge by rule 7): " + graph .get_edge (node_B , node_C ).__str__ ())
550+ return changeFlag
374551
375552def getPath (node_c : Node , previous ) -> List [Node ]:
376553 l = []
@@ -896,6 +1073,7 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
8961073 graph , sep_sets , test_results = fas (dataset , nodes , independence_test_method = independence_test_method , alpha = alpha ,
8971074 knowledge = background_knowledge , depth = depth , verbose = verbose , show_progress = show_progress )
8981075
1076+ # pdb.set_trace()
8991077 reorientAllWith (graph , Endpoint .CIRCLE )
9001078
9011079 rule0 (graph , nodes , sep_sets , background_knowledge , verbose )
@@ -926,8 +1104,18 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
9261104 if verbose :
9271105 print ("Epoch" )
9281106
1107+ # rule 5
1108+ change_flag = ruleR5 (graph , change_flag , verbose )
1109+
1110+ # rule 6
1111+ change_flag = ruleR6 (graph , change_flag , verbose )
1112+
1113+ # rule 7
1114+ change_flag = ruleR7 (graph , change_flag , verbose )
1115+
9291116 # rule 8
9301117 change_flag = rule8 (graph ,nodes )
1118+
9311119 # rule 9
9321120 change_flag = rule9 (graph , nodes )
9331121 # rule 10
0 commit comments