1010from causallearn .graph .Endpoint import Endpoint
1111from causallearn .graph .GeneralGraph import GeneralGraph
1212from causallearn .graph .Node import Node
13- from causallearn .search .ConstraintBased .FCI import rule0 , rulesR1R2cycle , ruleR3 , ruleR4B
13+ from causallearn .search .ConstraintBased .FCI import rule0 , rulesR1R2cycle , ruleR3 , ruleR4B , ruleR5 , ruleR6 , ruleR7 , rule8 , rule9 , rule10
1414from causallearn .utils .cit import CIT , d_separation
1515
16- def dag2pag (dag : Dag , islatent : List [Node ]) -> GeneralGraph :
16+
17+ def dag2pag (dag : Dag , islatent : List [Node ], isselection : List [Node ] = []) -> GeneralGraph :
1718 """
1819 Convert a DAG to its corresponding PAG
1920 Parameters
@@ -27,8 +28,8 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
2728 dg = nx .DiGraph ()
2829 true_dag = nx .DiGraph ()
2930 nodes = dag .get_nodes ()
30- observed_nodes = list (set (nodes ) - set (islatent ))
31- mod_nodes = observed_nodes + islatent
31+ observed_nodes = list (set (nodes ) - set (islatent ) - set ( isselection ) )
32+ mod_nodes = observed_nodes + islatent + isselection
3233 nodes = dag .get_nodes ()
3334 nodes_ids = {node : i for i , node in enumerate (nodes )}
3435 mod_nodeids = {node : i for i , node in enumerate (mod_nodes )}
@@ -65,7 +66,7 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
6566 for Z in combinations (observed_nodes , l ):
6667 if nodex in Z or nodey in Z :
6768 continue
68- if d_separated (dg , {nodes_ids [nodex ]}, {nodes_ids [nodey ]}, set (nodes_ids [z ] for z in Z )):
69+ if d_separated (dg , {nodes_ids [nodex ]}, {nodes_ids [nodey ]}, set (nodes_ids [z ] for z in Z ) | set ([ nodes_ids [ s ] for s in isselection ]) ):
6970 if edge :
7071 PAG .remove_edge (edge )
7172 sepset [(nodes_ids [nodex ], nodes_ids [nodey ])] |= set (Z )
@@ -105,6 +106,13 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
105106 change_flag = ruleR4B (PAG , - 1 , data , independence_test_method , 0.05 , sep_sets = sepset_reindexed ,
106107 change_flag = change_flag ,
107108 bk = None , verbose = False )
109+ change_flag = ruleR5 (PAG , changeFlag = change_flag , verbose = True )
110+ change_flag = ruleR6 (PAG , changeFlag = change_flag )
111+ change_flag = ruleR7 (PAG , changeFlag = change_flag )
112+ change_flag = rule8 (PAG , nodes = observed_nodes , changeFlag = change_flag )
113+ change_flag = rule9 (PAG , nodes = observed_nodes , changeFlag = change_flag )
114+ change_flag = rule10 (PAG , changeFlag = change_flag )
115+
108116 return PAG
109117
110118
0 commit comments