Skip to content

Commit a681ae8

Browse files
authored
Improve SkeletonDiscovery.py
1 parent 0f54fb4 commit a681ae8

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

causallearn/utils/PCUtils/SkeletonDiscovery.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,50 +34,54 @@ def skeleton_discovery(data, alpha, indep_test, stable=True, background_knowledg
3434
cg.data = data
3535
cg.set_ind_test(indep_test)
3636

37-
node_ids = range(no_of_var)
38-
pair_of_variables = list(permutations(node_ids, 2))
39-
4037
depth = -1
4138
while cg.max_degree() - 1 > depth:
4239
depth += 1
4340
edge_removal = []
44-
for (x, y) in pair_of_variables:
41+
for x in range(no_of_var):
4542
Neigh_x = cg.neighbors(x)
46-
if y not in Neigh_x:
43+
if len(Neigh_x) < depth - 1:
4744
continue
48-
else:
49-
Neigh_x = np.delete(Neigh_x, np.where(Neigh_x == y))
50-
51-
if len(Neigh_x) >= depth:
52-
for S in combinations(Neigh_x, depth):
53-
p = cg.ci_test(x, y, S)
54-
if p > alpha or (background_knowledge is not None and (
55-
background_knowledge.is_forbidden(cg.G.nodes[x],
56-
cg.G.nodes[y]) and background_knowledge.is_forbidden(
57-
cg.G.nodes[y], cg.G.nodes[x]))):
58-
if p > alpha:
59-
print('%d ind %d | %s with p-value %f\n' % (x, y, S, p))
60-
else:
61-
print('%d ind %d | %s with background knowledge\n' % (x, y, S))
62-
63-
if not stable: # Unstable: Remove x---y right away
45+
for y in Neigh_x:
46+
Neigh_x_noy = np.delete(Neigh_x, np.where(Neigh_x == y))
47+
for S in combinations(Neigh_x_noy, depth):
48+
if background_knowledge is not None and (
49+
background_knowledge.is_forbidden(cg.G.nodes[x], cg.G.nodes[y])
50+
and background_knowledge.is_forbidden(cg.G.nodes[y], cg.G.nodes[x])):
51+
print('%d ind %d | %s with background knowledge\n' % (x, y, S))
52+
if not stable:
6453
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
6554
if edge1 is not None:
6655
cg.G.remove_edge(edge1)
6756
edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
6857
if edge2 is not None:
6958
cg.G.remove_edge(edge2)
70-
else: # Stable: x---y will be removed only
59+
else:
7160
edge_removal.append((x, y)) # after all conditioning sets at
7261
edge_removal.append((y, x)) # depth l have been considered
7362
append_value(cg.sepset, x, y, S)
7463
append_value(cg.sepset, y, x, S)
7564
break
7665
else:
77-
if p <= alpha:
66+
p = cg.ci_test(x, y, S)
67+
if p > alpha:
68+
print('%d ind %d | %s with p-value %f\n' % (x, y, S, p))
69+
if not stable:
70+
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
71+
if edge1 is not None:
72+
cg.G.remove_edge(edge1)
73+
edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
74+
if edge2 is not None:
75+
cg.G.remove_edge(edge2)
76+
else:
77+
edge_removal.append((x, y)) # after all conditioning sets at
78+
edge_removal.append((y, x)) # depth l have been considered
79+
append_value(cg.sepset, x, y, S)
80+
append_value(cg.sepset, y, x, S)
81+
break
82+
else:
7883
print('%d dep %d | %s with p-value %f\n' % (x, y, S, p))
7984

80-
8185
for (x, y) in list(set(edge_removal)):
8286
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
8387
if edge1 is not None:

0 commit comments

Comments
 (0)