Skip to content

Commit a6479bb

Browse files
committed
Refactored DAG2PAG algorithm
Signed-off-by: ZhiyiHuang <[email protected]>
1 parent 13b06a2 commit a6479bb

File tree

1 file changed

+41
-43
lines changed

1 file changed

+41
-43
lines changed

causallearn/utils/DAG2PAG.py

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
from itertools import combinations, permutations
22
from typing import List
33

4+
import numpy as np
45
import networkx as nx
6+
from networkx.algorithms import d_separated
57

68
from causallearn.graph.Dag import Dag
79
from causallearn.graph.Edge import Edge
810
from causallearn.graph.Endpoint import Endpoint
911
from causallearn.graph.GeneralGraph import GeneralGraph
1012
from causallearn.graph.Node import Node
11-
from causallearn.search.ConstraintBased.FCI import ruleR3, rulesR1R2cycle
12-
13+
from causallearn.search.ConstraintBased.FCI import rule0, rulesR1R2cycle, ruleR3, ruleR4B
14+
from causallearn.utils.cit import CIT, d_separation
1315

1416
def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
1517
"""
@@ -22,15 +24,29 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
2224
-------
2325
PAG : Partial Ancestral Graph
2426
"""
25-
udg = nx.Graph()
27+
dg = nx.DiGraph()
28+
true_dag = nx.DiGraph()
29+
nodes = dag.get_nodes()
30+
observed_nodes = list(set(nodes) - set(islatent))
31+
mod_nodes = observed_nodes + islatent
2632
nodes = dag.get_nodes()
2733
nodes_ids = {node: i for i, node in enumerate(nodes)}
34+
mod_nodeids = {node: i for i, node in enumerate(mod_nodes)}
35+
2836
n = len(nodes)
37+
dg.add_nodes_from(range(n))
38+
true_dag.add_nodes_from(range(n))
39+
2940
for x, y in combinations(range(n), 2):
30-
if dag.get_edge(nodes[x], nodes[y]):
31-
udg.add_edge(x, y)
41+
edge = dag.get_edge(nodes[x], nodes[y])
42+
if edge:
43+
if edge.get_endpoint2() == Endpoint.ARROW:
44+
dg.add_edge(nodes_ids[edge.get_node1()], nodes_ids[edge.get_node2()])
45+
true_dag.add_edge(mod_nodeids[edge.get_node1()], mod_nodeids[edge.get_node2()])
46+
else:
47+
dg.add_edge(nodes_ids[edge.get_node2()], nodes_ids[edge.get_node1()])
48+
true_dag.add_edge(mod_nodeids[edge.get_node1()], mod_nodeids[edge.get_node2()])
3249

33-
observed_nodes = list(set(nodes) - set(islatent))
3450

3551
PAG = GeneralGraph(observed_nodes)
3652
for nodex, nodey in combinations(observed_nodes, 2):
@@ -41,43 +57,19 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
4157

4258
sepset = {(nodex, nodey): set() for nodex, nodey in permutations(observed_nodes, 2)}
4359

44-
for nodex, nodey in combinations(observed_nodes, 2):
45-
if nodex in islatent:
46-
continue
47-
if nodey in islatent:
48-
continue
49-
all_paths = nx.all_simple_paths(udg, nodes_ids[nodex], nodes_ids[nodey])
50-
noncolider_path = []
51-
is_connected = False
52-
for path in all_paths:
53-
path_sep = True
54-
has_nonlatent = False
55-
for i in range(1, len(path) - 1):
56-
if nodes[path[i]] in observed_nodes:
57-
has_nonlatent = True
58-
has_collider = is_endpoint(dag.get_edge(nodes[path[i - 1]], nodes[path[i]]), nodes[path[i]],
59-
Endpoint.ARROW) and \
60-
is_endpoint(dag.get_edge(nodes[path[i + 1]], nodes[path[i]]), nodes[path[i]],
61-
Endpoint.ARROW)
62-
if has_collider:
63-
path_sep = False
64-
if not path_sep:
65-
continue
66-
if has_nonlatent:
67-
noncolider_path.append(path)
68-
else:
69-
is_connected = True
70-
break
71-
if not is_connected:
60+
for l in range(0, len(observed_nodes) - 1):
61+
for nodex, nodey in combinations(observed_nodes, 2):
7262
edge = PAG.get_edge(nodex, nodey)
73-
if edge:
74-
PAG.remove_edge(edge)
75-
for path in noncolider_path:
76-
for i in range(1, len(path) - 1):
77-
if nodes[path[i]] in islatent:
78-
continue
79-
sepset[(nodex, nodey)] |= {nodes[path[i]]}
80-
sepset[(nodey, nodex)] |= {nodes[path[i]]}
63+
if not edge:
64+
continue
65+
for Z in combinations(observed_nodes, l):
66+
if nodex in Z or nodey in Z:
67+
continue
68+
if d_separated(dg, {nodes_ids[nodex]}, {nodes_ids[nodey]}, set(nodes_ids[z] for z in Z)):
69+
if edge:
70+
PAG.remove_edge(edge)
71+
sepset[(nodex, nodey)] |= set(Z)
72+
sepset[(nodey, nodex)] |= set(Z)
8173

8274
for nodex, nodey in combinations(observed_nodes, 2):
8375
if PAG.get_edge(nodex, nodey):
@@ -99,13 +91,19 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
9991
mod_endpoint(edge_yz, nodez, Endpoint.ARROW)
10092
PAG.add_edge(edge_yz)
10193

94+
print()
10295
change_flag = True
10396

97+
data = np.empty(shape=(0, len(observed_nodes)))
98+
independence_test_method = CIT(data, method=d_separation, true_dag=true_dag)
99+
104100
while change_flag:
105101
change_flag = False
106102
change_flag = rulesR1R2cycle(PAG, None, change_flag, False)
107103
change_flag = ruleR3(PAG, sepset, None, change_flag, False)
108-
104+
change_flag = ruleR4B(PAG, -1, data, independence_test_method, 0.05, sep_sets=sepset,
105+
change_flag=change_flag,
106+
bk=None, verbose=False)
109107
return PAG
110108

111109

0 commit comments

Comments
 (0)