22
33import warnings
44from queue import Queue
5- from typing import List , Set , Tuple , Dict
5+ from typing import List , Set , Tuple , Dict , Generator
66from numpy import ndarray
77
88from causallearn .graph .Edge import Edge
1717from causallearn .utils .PCUtils .BackgroundKnowledge import BackgroundKnowledge
1818from itertools import combinations
1919
20+ def is_uncovered_path (nodes : List [Node ], G : Graph ) -> bool :
21+ """
22+ Determines whether the given path is an uncovered path in this graph.
23+
24+ A path is an uncovered path if no two nonconsecutive nodes (Vi-1 and Vi+1) in the path are
25+ adjacent.
26+ """
27+ for i in range (len (nodes ) - 2 ):
28+ if G .is_adjacent_to (nodes [i ], nodes [i + 2 ]):
29+ return False
30+ return True
31+
32+
2033def traverseSemiDirected (node : Node , edge : Edge ) -> Node | None :
2134 if node == edge .get_node1 ():
2235 if edge .get_endpoint1 () == Endpoint .TAIL or edge .get_endpoint1 () == Endpoint .CIRCLE :
@@ -26,8 +39,17 @@ def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
2639 return edge .get_node1 ()
2740 return None
2841
42+ def traverseCircle (node : Node , edge : Edge ) -> Node | None :
43+ if node == edge .get_node1 ():
44+ if edge .get_endpoint1 () == Endpoint .CIRCLE and edge .get_endpoint2 () == Endpoint .CIRCLE :
45+ return edge .get_node2 ()
46+ elif node == edge .get_node2 ():
47+ if edge .get_endpoint1 () == Endpoint .CIRCLE and edge .get_endpoint2 () == Endpoint .CIRCLE :
48+ return edge .get_node1 ()
49+ return None
50+
2951
30- def existsSemiDirectedPath (node_from : Node , node_to : Node , G : Graph ) -> bool :
52+ def existsSemiDirectedPath (node_from : Node , node_to : Node , G : Graph ) -> bool : ## TODO: Now it does not detect whether the path is an uncovered path
3153 Q = Queue ()
3254 V = set ()
3355
@@ -60,6 +82,42 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool:
6082
6183 return False
6284
85+ def GetUncoveredCirclePath (node_from : Node , node_to : Node , G : Graph , exclude_node : List [Node ]) -> Generator [Node ] | None :
86+ Q = Queue ()
87+ V = set ()
88+
89+ path = [node_from ]
90+
91+ for node_u in G .get_adjacent_nodes (node_from ):
92+ if node_u in exclude_node :
93+ continue
94+ edge = G .get_edge (node_from , node_u )
95+ node_c = traverseCircle (node_from , edge )
96+
97+ if node_c is None or node_c in exclude_node :
98+ continue
99+
100+ if not V .__contains__ (node_c ):
101+ V .add (node_c )
102+ Q .put ((node_c , path + [node_c ]))
103+
104+ while not Q .empty ():
105+ node_t , path = Q .get_nowait ()
106+ if node_t == node_to and is_uncovered_path (path , G ):
107+ yield path
108+
109+ for node_u in G .get_adjacent_nodes (node_t ):
110+ edge = G .get_edge (node_t , node_u )
111+ node_c = traverseCircle (node_t , edge )
112+
113+ if node_c is None or node_c in exclude_node :
114+ continue
115+
116+ if not V .__contains__ (node_c ):
117+ V .add (node_c )
118+ Q .put ((node_c , path + [node_c ]))
119+
120+
63121
64122def existOnePathWithPossibleParents (previous , node_w : Node , node_x : Node , node_b : Node , graph : Graph ) -> bool :
65123 if node_w == node_x :
@@ -371,6 +429,131 @@ def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: Backgrou
371429 changeFlag = True
372430 return changeFlag
373431
432+ def ruleR5 (graph : Graph , changeFlag : bool ,
433+ verbose : bool = False ) -> bool :
434+ """
435+ Rule R5 of the FCI algorithm.
436+ by Jiji Zhang, 2008, "On the completeness of orientation rules for causal discovery in the presence of latent confounders and selection bias"]
437+
438+ This function orients any edge that is part of an uncovered circle path between two nodes A and B,
439+ if such a path exists. The path must start and end with a circle edge and must be uncovered, i.e. the
440+ nodes on the path must not be adjacent to A or B. The orientation of the edges on the path is set to
441+ double tail.
442+ """
443+ nodes = graph .get_nodes ()
444+ def orient_on_path_helper (path , node_A , node_B ):
445+ # orient A - C, D - B
446+ edge = graph .get_edge (node_A , path [0 ])
447+ graph .remove_edge (edge )
448+ graph .add_edge (Edge (node_A , path [0 ], Endpoint .TAIL , Endpoint .TAIL ))
449+
450+ edge = graph .get_edge (node_B , path [- 1 ])
451+ graph .remove_edge (edge )
452+ graph .add_edge (Edge (node_B , path [- 1 ], Endpoint .TAIL , Endpoint .TAIL ))
453+ if verbose :
454+ print ("Orienting edge A - C (Double tail): " + graph .get_edge (node_A , path [0 ]).__str__ ())
455+ print ("Orienting edge B - D (Double tail): " + graph .get_edge (node_B , path [- 1 ]).__str__ ())
456+
457+ # orient everything on the path to both tails
458+ for i in range (len (path ) - 1 ):
459+ edge = graph .get_edge (path [i ], path [i + 1 ])
460+ graph .remove_edge (edge )
461+ graph .add_edge (Edge (path [i ], path [i + 1 ], Endpoint .TAIL , Endpoint .TAIL ))
462+ if verbose :
463+ print ("Orienting edge (Double tail): " + graph .get_edge (path [i ], path [i + 1 ]).__str__ ())
464+
465+ for node_B in nodes :
466+ intoBCircles = graph .get_nodes_into (node_B , Endpoint .CIRCLE )
467+
468+ for node_A in intoBCircles :
469+ found_paths_between_AB = []
470+ if graph .get_endpoint (node_B , node_A ) != Endpoint .CIRCLE :
471+ continue
472+ else :
473+ # Check if there is an uncovered circle path between A and B (A o-o C .. D o-o B)
474+ # s.t. A is not adjacent to D and B is not adjacent to C
475+ a_node_idx = graph .node_map [node_A ]
476+ b_node_idx = graph .node_map [node_B ]
477+ a_adj_nodes = graph .get_adjacent_nodes (node_A )
478+ b_adj_nodes = graph .get_adjacent_nodes (node_B )
479+
480+ # get the adjacent nodes with circle edges of A and B
481+ a_circle_adj_nodes_set = [node for node in a_adj_nodes if graph .node_map [node ] != a_node_idx and graph .node_map [node ]!= b_node_idx
482+ and graph .get_endpoint (node , node_A ) == Endpoint .CIRCLE and graph .get_endpoint (node_A , node ) == Endpoint .CIRCLE ]
483+ b_circle_adj_nodes_set = [node for node in b_adj_nodes if graph .node_map [node ] != a_node_idx and graph .node_map [node ]!= b_node_idx
484+ and graph .get_endpoint (node , node_B ) == Endpoint .CIRCLE and graph .get_endpoint (node_B , node ) == Endpoint .CIRCLE ]
485+
486+ # get the adjacent nodes with circle edges of A and B that is non adjacent to B and A, respectively
487+ for node_C in a_circle_adj_nodes_set :
488+ if graph .is_adjacent_to (node_B , node_C ):
489+ continue
490+ for node_D in b_circle_adj_nodes_set :
491+ if graph .is_adjacent_to (node_A , node_D ):
492+ continue
493+ paths = GetUncoveredCirclePath (node_from = node_C , node_to = node_D , G = graph , exclude_node = [node_A , node_B ]) # get the uncovered circle path between C and D, excluding A and B
494+ found_paths_between_AB .append (paths )
495+
496+ # Orient the uncovered circle path between A and B
497+ for paths in found_paths_between_AB :
498+ for path in paths :
499+ changeFlag = True
500+ if verbose :
501+ print ("Find uncovered circle path between A and B: " + graph .get_edge (node_A , node_B ).__str__ ())
502+ edge = graph .get_edge (node_A , node_B )
503+ graph .remove_edge (edge )
504+ graph .add_edge (Edge (node_A , node_B , Endpoint .TAIL , Endpoint .TAIL ))
505+ orient_on_path_helper (path , node_A , node_B )
506+
507+ return changeFlag
508+
509+ def ruleR6 (graph : Graph , changeFlag : bool ,
510+ verbose : bool = False ) -> bool :
511+ nodes = graph .get_nodes ()
512+
513+ for node_B in nodes :
514+ # Find A - B
515+ intoBTails = graph .get_nodes_into (node_B , Endpoint .TAIL )
516+ exist = False
517+ for node_A in intoBTails :
518+ if graph .get_endpoint (node_B , node_A ) == Endpoint .TAIL :
519+ exist = True
520+ if not exist :
521+ continue
522+ # Find B o-*C
523+ intoBCircles = graph .get_nodes_into (node_B , Endpoint .CIRCLE )
524+ for node_C in intoBCircles :
525+ changeFlag = True
526+ edge = graph .get_edge (node_B , node_C )
527+ graph .remove_edge (edge )
528+ graph .add_edge (Edge (node_B , node_C , Endpoint .TAIL , edge .get_proximal_endpoint (node_C )))
529+ if verbose :
530+ print ("Orienting edge by rule 6): " + graph .get_edge (node_B , node_C ).__str__ ())
531+
532+ return changeFlag
533+
534+
535+ def ruleR7 (graph : Graph , changeFlag : bool ,
536+ verbose : bool = False ) -> bool :
537+ nodes = graph .get_nodes ()
538+
539+ for node_B in nodes :
540+ # Find A -o B
541+ intoBCircles = graph .get_nodes_into (node_B , Endpoint .CIRCLE )
542+ node_A_list = [node for node in intoBCircles if graph .get_endpoint (node_B , node ) == Endpoint .TAIL ]
543+
544+ # Find B o-*C
545+ for node_C in intoBCircles :
546+ # pdb.set_trace()
547+ for node_A in node_A_list :
548+ # pdb.set_trace()
549+ if not graph .is_adjacent_to (node_A , node_C ):
550+ changeFlag = True
551+ edge = graph .get_edge (node_B , node_C )
552+ graph .remove_edge (edge )
553+ graph .add_edge (Edge (node_B , node_C , Endpoint .TAIL , edge .get_proximal_endpoint (node_C )))
554+ if verbose :
555+ print ("Orienting edge by rule 7): " + graph .get_edge (node_B , node_C ).__str__ ())
556+ return changeFlag
374557
375558def getPath (node_c : Node , previous ) -> List [Node ]:
376559 l = []
@@ -544,9 +727,8 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m
544727
545728
546729
547- def rule8 (graph : Graph , nodes : List [Node ]):
548- nodes = graph .get_nodes ()
549- changeFlag = False
730+ def rule8 (graph : Graph , nodes : List [Node ], changeFlag ):
731+ nodes = graph .get_nodes () if nodes is None else nodes
550732 for node_B in nodes :
551733 adj = graph .get_adjacent_nodes (node_B )
552734 if len (adj ) < 2 :
@@ -601,9 +783,9 @@ def find_possible_children(graph: Graph, parent_node, en_nodes=None):
601783
602784 return potential_child_nodes
603785
604- def rule9 (graph : Graph , nodes : List [Node ]):
605- changeFlag = False
606- nodes = graph .get_nodes ()
786+ def rule9 (graph : Graph , nodes : List [Node ], changeFlag ):
787+ # changeFlag = False
788+ nodes = graph .get_nodes () if nodes is None else nodes
607789 for node_C in nodes :
608790 intoCArrows = graph .get_nodes_into (node_C , Endpoint .ARROW )
609791 for node_A in intoCArrows :
@@ -629,8 +811,8 @@ def rule9(graph: Graph, nodes: List[Node]):
629811 return changeFlag
630812
631813
632- def rule10 (graph : Graph ):
633- changeFlag = False
814+ def rule10 (graph : Graph , changeFlag ):
815+ # changeFlag = False
634816 nodes = graph .get_nodes ()
635817 for node_C in nodes :
636818 intoCArrows = graph .get_nodes_into (node_C , Endpoint .ARROW )
@@ -895,6 +1077,7 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
8951077 graph , sep_sets , test_results = fas (dataset , nodes , independence_test_method = independence_test_method , alpha = alpha ,
8961078 knowledge = background_knowledge , depth = depth , verbose = verbose , show_progress = show_progress )
8971079
1080+ # pdb.set_trace()
8981081 reorientAllWith (graph , Endpoint .CIRCLE )
8991082
9001083 rule0 (graph , nodes , sep_sets , background_knowledge , verbose )
@@ -925,12 +1108,22 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
9251108 if verbose :
9261109 print ("Epoch" )
9271110
1111+ # rule 5
1112+ change_flag = ruleR5 (graph , change_flag , verbose )
1113+
1114+ # rule 6
1115+ change_flag = ruleR6 (graph , change_flag , verbose )
1116+
1117+ # rule 7
1118+ change_flag = ruleR7 (graph , change_flag , verbose )
1119+
9281120 # rule 8
929- change_flag = rule8 (graph ,nodes )
1121+ change_flag = rule8 (graph ,nodes , change_flag )
1122+
9301123 # rule 9
931- change_flag = rule9 (graph , nodes )
1124+ change_flag = rule9 (graph , nodes , change_flag )
9321125 # rule 10
933- change_flag = rule10 (graph )
1126+ change_flag = rule10 (graph , change_flag )
9341127
9351128 graph .set_pag (True )
9361129
0 commit comments