11from itertools import combinations , permutations
22from typing import List
33
4+ import numpy as np
45import networkx as nx
6+ from networkx .algorithms import d_separated
57
68from causallearn .graph .Dag import Dag
79from causallearn .graph .Edge import Edge
810from causallearn .graph .Endpoint import Endpoint
911from causallearn .graph .GeneralGraph import GeneralGraph
1012from causallearn .graph .Node import Node
11- from causallearn .search .ConstraintBased .FCI import ruleR3 , rulesR1R2cycle
12-
13+ from causallearn .search .ConstraintBased .FCI import rule0 , rulesR1R2cycle , ruleR3 , ruleR4B
14+ from causallearn . utils . cit import CIT , d_separation
1315
1416def dag2pag (dag : Dag , islatent : List [Node ]) -> GeneralGraph :
1517 """
@@ -22,15 +24,29 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
2224 -------
2325 PAG : Partial Ancestral Graph
2426 """
25- udg = nx .Graph ()
27+ dg = nx .DiGraph ()
28+ true_dag = nx .DiGraph ()
29+ nodes = dag .get_nodes ()
30+ observed_nodes = list (set (nodes ) - set (islatent ))
31+ mod_nodes = observed_nodes + islatent
2632 nodes = dag .get_nodes ()
2733 nodes_ids = {node : i for i , node in enumerate (nodes )}
34+ mod_nodeids = {node : i for i , node in enumerate (mod_nodes )}
35+
2836 n = len (nodes )
37+ dg .add_nodes_from (range (n ))
38+ true_dag .add_nodes_from (range (n ))
39+
2940 for x , y in combinations (range (n ), 2 ):
30- if dag .get_edge (nodes [x ], nodes [y ]):
31- udg .add_edge (x , y )
41+ edge = dag .get_edge (nodes [x ], nodes [y ])
42+ if edge :
43+ if edge .get_endpoint2 () == Endpoint .ARROW :
44+ dg .add_edge (nodes_ids [edge .get_node1 ()], nodes_ids [edge .get_node2 ()])
45+ true_dag .add_edge (mod_nodeids [edge .get_node1 ()], mod_nodeids [edge .get_node2 ()])
46+ else :
47+ dg .add_edge (nodes_ids [edge .get_node2 ()], nodes_ids [edge .get_node1 ()])
48+ true_dag .add_edge (mod_nodeids [edge .get_node1 ()], mod_nodeids [edge .get_node2 ()])
3249
33- observed_nodes = list (set (nodes ) - set (islatent ))
3450
3551 PAG = GeneralGraph (observed_nodes )
3652 for nodex , nodey in combinations (observed_nodes , 2 ):
@@ -41,43 +57,19 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
4157
4258 sepset = {(nodex , nodey ): set () for nodex , nodey in permutations (observed_nodes , 2 )}
4359
44- for nodex , nodey in combinations (observed_nodes , 2 ):
45- if nodex in islatent :
46- continue
47- if nodey in islatent :
48- continue
49- all_paths = nx .all_simple_paths (udg , nodes_ids [nodex ], nodes_ids [nodey ])
50- noncolider_path = []
51- is_connected = False
52- for path in all_paths :
53- path_sep = True
54- has_nonlatent = False
55- for i in range (1 , len (path ) - 1 ):
56- if nodes [path [i ]] in observed_nodes :
57- has_nonlatent = True
58- has_collider = is_endpoint (dag .get_edge (nodes [path [i - 1 ]], nodes [path [i ]]), nodes [path [i ]],
59- Endpoint .ARROW ) and \
60- is_endpoint (dag .get_edge (nodes [path [i + 1 ]], nodes [path [i ]]), nodes [path [i ]],
61- Endpoint .ARROW )
62- if has_collider :
63- path_sep = False
64- if not path_sep :
65- continue
66- if has_nonlatent :
67- noncolider_path .append (path )
68- else :
69- is_connected = True
70- break
71- if not is_connected :
60+ for l in range (0 , len (observed_nodes ) - 1 ):
61+ for nodex , nodey in combinations (observed_nodes , 2 ):
7262 edge = PAG .get_edge (nodex , nodey )
73- if edge :
74- PAG .remove_edge (edge )
75- for path in noncolider_path :
76- for i in range (1 , len (path ) - 1 ):
77- if nodes [path [i ]] in islatent :
78- continue
79- sepset [(nodex , nodey )] |= {nodes [path [i ]]}
80- sepset [(nodey , nodex )] |= {nodes [path [i ]]}
63+ if not edge :
64+ continue
65+ for Z in combinations (observed_nodes , l ):
66+ if nodex in Z or nodey in Z :
67+ continue
68+ if d_separated (dg , {nodes_ids [nodex ]}, {nodes_ids [nodey ]}, set (nodes_ids [z ] for z in Z )):
69+ if edge :
70+ PAG .remove_edge (edge )
71+ sepset [(nodex , nodey )] |= set (Z )
72+ sepset [(nodey , nodex )] |= set (Z )
8173
8274 for nodex , nodey in combinations (observed_nodes , 2 ):
8375 if PAG .get_edge (nodex , nodey ):
@@ -99,13 +91,19 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
9991 mod_endpoint (edge_yz , nodez , Endpoint .ARROW )
10092 PAG .add_edge (edge_yz )
10193
94+ print ()
10295 change_flag = True
10396
97+ data = np .empty (shape = (0 , len (observed_nodes )))
98+ independence_test_method = CIT (data , method = d_separation , true_dag = true_dag )
99+
104100 while change_flag :
105101 change_flag = False
106102 change_flag = rulesR1R2cycle (PAG , None , change_flag , False )
107103 change_flag = ruleR3 (PAG , sepset , None , change_flag , False )
108-
104+ change_flag = ruleR4B (PAG , - 1 , data , independence_test_method , 0.05 , sep_sets = sepset ,
105+ change_flag = change_flag ,
106+ bk = None , verbose = False )
109107 return PAG
110108
111109
0 commit comments