Skip to content

Commit 53ddcc0

Browse files
committed
DAG to a maximal PAG by adding rule5-10 and adding selection variables
1 parent a944e26 commit 53ddcc0

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

causallearn/utils/DAG2PAG.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
from causallearn.graph.Endpoint import Endpoint
1111
from causallearn.graph.GeneralGraph import GeneralGraph
1212
from causallearn.graph.Node import Node
13-
from causallearn.search.ConstraintBased.FCI import rule0, rulesR1R2cycle, ruleR3, ruleR4B
13+
from causallearn.search.ConstraintBased.FCI import rule0, rulesR1R2cycle, ruleR3, ruleR4B, ruleR5, ruleR6, ruleR7, rule8, rule9, rule10
1414
from causallearn.utils.cit import CIT, d_separation
1515

16-
def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
16+
17+
def dag2pag(dag: Dag, islatent: List[Node], isselection: List[Node] = []) -> GeneralGraph:
1718
"""
1819
Convert a DAG to its corresponding PAG
1920
Parameters
@@ -27,8 +28,8 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
2728
dg = nx.DiGraph()
2829
true_dag = nx.DiGraph()
2930
nodes = dag.get_nodes()
30-
observed_nodes = list(set(nodes) - set(islatent))
31-
mod_nodes = observed_nodes + islatent
31+
observed_nodes = list(set(nodes) - set(islatent) - set(isselection))
32+
mod_nodes = observed_nodes + islatent + isselection
3233
nodes = dag.get_nodes()
3334
nodes_ids = {node: i for i, node in enumerate(nodes)}
3435
mod_nodeids = {node: i for i, node in enumerate(mod_nodes)}
@@ -65,7 +66,7 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
6566
for Z in combinations(observed_nodes, l):
6667
if nodex in Z or nodey in Z:
6768
continue
68-
if d_separated(dg, {nodes_ids[nodex]}, {nodes_ids[nodey]}, set(nodes_ids[z] for z in Z)):
69+
if d_separated(dg, {nodes_ids[nodex]}, {nodes_ids[nodey]}, set(nodes_ids[z] for z in Z) | set([nodes_ids[s] for s in isselection])):
6970
if edge:
7071
PAG.remove_edge(edge)
7172
sepset[(nodes_ids[nodex], nodes_ids[nodey])] |= set(Z)
@@ -105,6 +106,13 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
105106
change_flag = ruleR4B(PAG, -1, data, independence_test_method, 0.05, sep_sets=sepset_reindexed,
106107
change_flag=change_flag,
107108
bk=None, verbose=False)
109+
change_flag = ruleR5(PAG, changeFlag=change_flag, verbose=True)
110+
change_flag = ruleR6(PAG, changeFlag=change_flag)
111+
change_flag = ruleR7(PAG, changeFlag=change_flag)
112+
change_flag = rule8(PAG, nodes=observed_nodes, changeFlag=change_flag)
113+
change_flag = rule9(PAG, nodes=observed_nodes, changeFlag=change_flag)
114+
change_flag = rule10(PAG, changeFlag=change_flag)
115+
108116
return PAG
109117

110118

tests/TestDAG2PAG.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,17 @@ def test_case3(self):
7272
print(pag)
7373
graphviz_pag = GraphUtils.to_pgv(pag)
7474
graphviz_pag.draw("pag.png", prog='dot', format='png')
75+
76+
def test_case_selection(self):
77+
nodes = []
78+
for i in range(5):
79+
nodes.append(GraphNode(str(i)))
80+
dag = Dag(nodes)
81+
dag.add_directed_edge(nodes[0], nodes[1])
82+
dag.add_directed_edge(nodes[1], nodes[2])
83+
dag.add_directed_edge(nodes[2], nodes[3])
84+
# Selection nodes
85+
dag.add_directed_edge(nodes[3], nodes[4])
86+
dag.add_directed_edge(nodes[0], nodes[4])
87+
pag = dag2pag(dag, islatent=[], isselection=[nodes[4]])
88+
print(pag)

0 commit comments

Comments
 (0)