
Commit 99abff0

Fix GIN

1 parent 21ccd97 commit 99abff0

2 files changed: +189 −5 lines changed

causallearn/search/HiddenCausal/GIN/GIN.py

Lines changed: 165 additions & 5 deletions
@@ -5,18 +5,154 @@
 From DMIRLab: https://dmir.gdut.edu.cn/
 '''
 
+import random
 from collections import deque
 from itertools import combinations
 
 import numpy as np
+from scipy.stats import chi2
 
 from causallearn.graph.GeneralGraph import GeneralGraph
 from causallearn.graph.GraphNode import GraphNode
 from causallearn.graph.NodeType import NodeType
+from causallearn.graph.Edge import Edge
+from causallearn.graph.Endpoint import Endpoint
 from causallearn.search.FCMBased.lingam.hsic import hsic_test_gamma
+from causallearn.utils.cit import kci
 
 
-def GIN(data):
+def fisher_test(pvals):
+    # Conservative combination: floor tiny p-values, then take the minimum.
+    # (Fisher's chi-square combination is kept below for reference.)
+    pvals = [pval if pval >= 1e-5 else 1e-5 for pval in pvals]
+    return min(pvals)
+    # fisher_stat = -2.0 * np.sum(np.log(pvals))
+    # return 1 - chi2.cdf(fisher_stat, 2 * len(pvals))
+
+
+def GIN(data, indep_test=kci, alpha=0.05):
+    '''
+    Learning causal structure of latent variables for the linear non-Gaussian
+    latent variable model with the Generalized Independent Noise (GIN) condition
+
+    Parameters
+    ----------
+    data : numpy ndarray
+        data set
+    indep_test : callable, default=kci
+        the function of the independence test being used
+    alpha : float, default=0.05
+        desired significance level of independence tests (p_value) in (0, 1)
+
+    Returns
+    -------
+    G : general graph
+        causal graph
+    K : list
+        causal order
+    '''
+    n = data.shape[1]
+    cov = np.cov(data.T)
+
+    # Phase 1: find causal clusters by testing the GIN condition on every
+    # candidate cluster of increasing size.
+    var_set = set(range(n))
+    cluster_size = 2
+    clusters_list = []
+    while cluster_size < len(var_set):
+        tmp_clusters_list = []
+        for cluster in combinations(var_set, cluster_size):
+            remain_var_set = var_set - set(cluster)
+            e = cal_e_with_gin(data, cov, list(cluster), list(remain_var_set))
+            pvals = []
+            tmp_data = np.concatenate([data[:, list(remain_var_set)], e.reshape(-1, 1)], axis=1)
+            for z in range(len(remain_var_set)):
+                pvals.append(indep_test(tmp_data, z, -1))
+            fisher_pval = fisher_test(pvals)
+            if fisher_pval >= alpha:
+                tmp_clusters_list.append(cluster)
+        tmp_clusters_list = merge_overlaping_cluster(tmp_clusters_list)
+        clusters_list = clusters_list + tmp_clusters_list
+        for cluster in tmp_clusters_list:
+            var_set -= set(cluster)
+        cluster_size += 1
+
+    # Phase 2: determine the causal order K by repeatedly extracting a root
+    # cluster from the remaining clusters.
+    K = []
+    updated = True
+    while updated:
+        updated = False
+        X = []
+        Z = []
+        for cluster_k in K:
+            cluster_k1, cluster_k2 = array_split(cluster_k, 2)
+            X += cluster_k1
+            Z += cluster_k2
+
+        for i, cluster_i in enumerate(clusters_list):
+            is_root = True
+            random.shuffle(cluster_i)
+            cluster_i1, cluster_i2 = array_split(cluster_i, 2)
+            for j, cluster_j in enumerate(clusters_list):
+                if i == j:
+                    continue
+                random.shuffle(cluster_j)
+                cluster_j1, cluster_j2 = array_split(cluster_j, 2)
+                e = cal_e_with_gin(data, cov, X + cluster_i1 + cluster_j1, Z + cluster_i2)
+                pvals = []
+                tmp_data = np.concatenate([data[:, Z + cluster_i2], e.reshape(-1, 1)], axis=1)
+                for z in range(len(Z + cluster_i2)):
+                    pvals.append(indep_test(tmp_data, z, -1))
+                fisher_pval = fisher_test(pvals)
+                if fisher_pval < alpha:
+                    is_root = False
+                    break
+            if is_root:
+                K.append(cluster_i)
+                clusters_list.remove(cluster_i)
+                updated = True
+                break
+
+    # Build the output graph: one latent per cluster, directed edges along the
+    # causal order, undirected (tail-tail) edges between unordered latents.
+    G = GeneralGraph([])
+    for var in var_set:
+        o_node = GraphNode(f"X{var + 1}")
+        G.add_node(o_node)
+
+    latent_id = 1
+    l_nodes = []
+
+    for cluster in K:
+        l_node = GraphNode(f"L{latent_id}")
+        l_node.set_node_type(NodeType.LATENT)
+        G.add_node(l_node)
+        for l in l_nodes:
+            G.add_directed_edge(l, l_node)
+        l_nodes.append(l_node)
+
+        for o in cluster:
+            o_node = GraphNode(f"X{o + 1}")
+            G.add_node(o_node)
+            G.add_directed_edge(l_node, o_node)
+        latent_id += 1
+
+    undirected_l_nodes = []
+
+    for cluster in clusters_list:
+        l_node = GraphNode(f"L{latent_id}")
+        l_node.set_node_type(NodeType.LATENT)
+        G.add_node(l_node)
+        for l in l_nodes:
+            G.add_directed_edge(l, l_node)
+
+        for l in undirected_l_nodes:
+            G.add_edge(Edge(l, l_node, Endpoint.TAIL, Endpoint.TAIL))
+
+        undirected_l_nodes.append(l_node)
+
+        for o in cluster:
+            o_node = GraphNode(f"X{o + 1}")
+            G.add_node(o_node)
+            G.add_directed_edge(l_node, o_node)
+        latent_id += 1
+
+    return G, K
+
+
+def GIN_MI(data):
     '''
     Learning causal structure of Latent Variables for Linear Non-Gaussian Latent Variable Model
     with Generalized Independent Noise Condition
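
Note on fisher_test: despite the name, the committed version returns the minimum (floored) p-value rather than Fisher's combined p-value; the commented-out lines correspond to Fisher's chi-square combination. A minimal standalone sketch of that combination, for reference (assumes only numpy and scipy; fisher_combine is a hypothetical name, not part of the commit):

    import numpy as np
    from scipy.stats import chi2

    def fisher_combine(pvals):
        # Under H0, -2 * sum(log p_i) ~ chi-square with 2k degrees of freedom.
        pvals = np.clip(pvals, 1e-5, 1.0)  # same floor as fisher_test
        stat = -2.0 * np.sum(np.log(pvals))
        return 1.0 - chi2.cdf(stat, 2 * len(pvals))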
@@ -81,6 +217,13 @@ def GIN(data):
     return G, K
 
 
+def cal_e_with_gin(data, cov, X, Z):
+    # omega is the right singular vector for the smallest singular value of
+    # cov(Z, X), so cov(Z, omega^T X) is (approximately) the zero vector.
+    cov_m = cov[np.ix_(Z, X)]
+    _, _, v = np.linalg.svd(cov_m)
+    omega = v.T[:, -1]
+    return np.dot(omega, data[:, X].T)
+
+
 def cal_dep_for_gin(data, cov, X, Z):
     '''
     Calculate the statistics of dependence via Generalized Independent Noise Condition
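
The extracted helper computes the GIN surrogate E = omega^T X, where omega minimizes ||cov(Z, X) w|| over unit vectors w; under a one-latent-factor model cov(Z, X) has rank one, so the minimum is near zero and E is uncorrelated with every variable in Z. A quick sketch demonstrating this property (the toy generative model and index sets here are illustrative assumptions, not from the commit):

    import numpy as np

    rng = np.random.default_rng(0)
    n = 5000
    L = rng.uniform(-1, 1, n) ** 5                 # one non-Gaussian latent
    data = np.column_stack([
        2.0 * L + rng.uniform(-1, 1, n) ** 5,      # X1
        0.7 * L + rng.uniform(-1, 1, n) ** 5,      # X2
        1.3 * L + rng.uniform(-1, 1, n) ** 5,      # Z1
        0.9 * L + rng.uniform(-1, 1, n) ** 5,      # Z2
    ])
    cov = np.cov(data.T)
    X, Z = [0, 1], [2, 3]

    cov_m = cov[np.ix_(Z, X)]       # cross-covariance, rank ~1 here
    _, _, v = np.linalg.svd(cov_m)
    omega = v.T[:, -1]              # null direction of cov(Z, X)
    e = data[:, X] @ omega          # surrogate E = omega^T X
    print(np.abs(cov_m @ omega))    # ~ 0: E is uncorrelated with Z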
@@ -96,10 +239,8 @@ def cal_dep_for_gin(data, cov, X, Z):
     -------
     sta : test statistic
     '''
-    cov_m = cov[np.ix_(Z, X)]
-    _, _, v = np.linalg.svd(cov_m)
-    omega = v.T[:, -1]
-    e_xz = np.dot(omega, data[:, X].T)
+
+    e_xz = cal_e_with_gin(data, cov, X, Z)
 
     sta = 0
     for i in Z:
@@ -160,6 +301,8 @@ def _get_all_elements(S):
 # merging cluster
 def merge_overlaping_cluster(cluster_list):
     v_labels = _get_all_elements(cluster_list)
+    if len(v_labels) == 0:
+        return []
     cluster_dict = {i: -1 for i in v_labels}
     cluster_b = {i: [] for i in v_labels}
     cluster_len = 0
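
The added guard makes the merge step return an empty result for an empty candidate list instead of building its bookkeeping dicts over an empty label set. For intuition, merging is by shared variables: candidate clusters that overlap are unioned into one. A hedged illustration of the expected behavior (exact output ordering and container types may differ from the actual implementation):

    merge_overlaping_cluster([])                        # -> [] (the new guard)
    merge_overlaping_cluster([(0, 1), (1, 2), (3, 4)])
    # -> roughly [[0, 1, 2], [3, 4]]: (0, 1) and (1, 2) share variable 1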
@@ -197,3 +340,20 @@ def merge_overlaping_cluster(cluster_list):
         cluster[cluster_dict[i]].append(i)
 
     return cluster
+
+
+def array_split(x, k):
+    # Split list x into k contiguous chunks whose sizes differ by at most one
+    # (list analogue of numpy.array_split).
+    x_len = len(x)
+    sub_arys = []
+    start = 0
+    section_len = x_len // k
+    extra = x_len % k
+    for i in range(extra):
+        sub_arys.append(x[start:start + section_len + 1])
+        start = start + section_len + 1
+
+    for i in range(k - extra):
+        sub_arys.append(x[start:start + section_len])
+        start = start + section_len
+    return sub_arys
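
array_split mirrors numpy.array_split for plain Python lists: the first len(x) % k chunks receive one extra element. For example:

    array_split([1, 2, 3, 4, 5], 2)   # -> [[1, 2, 3], [4, 5]]
    array_split([0, 1, 2], 3)         # -> [[0], [1], [2]]

GIN uses it above to split each cluster into two halves when testing for root clusters.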

tests/TestGIN.py

Lines changed: 24 additions & 0 deletions
@@ -1,3 +1,4 @@
+import random
 import sys
 
 sys.path.append("")
@@ -45,3 +46,26 @@ def test_case2(self):
         data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
         g, k = GIN(data)
         print(g, k)
+
+    def test_case3(self):
+        # Ground truth: latent chain L1 -> L2 -> L3 -> L4; uniform**5 noise
+        # keeps the model non-Gaussian, as GIN requires.
+        sample_size = 1000
+        random.seed(42)
+        np.random.seed(42)
+        L1 = np.random.uniform(-1, 1, size=sample_size) ** 5
+        L2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        L3 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        L4 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
+
+        X1 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X2 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X3 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X4 = np.random.uniform(0.5, 2.0) * L1 + np.random.uniform(0.5, 2.0) * L2 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X5 = np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X6 = np.random.uniform(0.5, 2.0) * L3 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X7 = np.random.uniform(0.5, 2.0) * L4 + np.random.uniform(-1, 1, size=sample_size) ** 5
+        X8 = np.random.uniform(0.5, 2.0) * L4 + np.random.uniform(-1, 1, size=sample_size) ** 5
+
+        data = np.array([X1, X2, X3, X4, X5, X6, X7, X8]).T
+        data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
+        g, k = GIN(data)
+        print(g, k)
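
test_case3 exercises only the defaults; the new signature also lets callers swap the independence test or change the significance level. A minimal sketch (toy data; kci and 0.05 are just the defaults made explicit):

    import numpy as np
    from causallearn.search.HiddenCausal.GIN.GIN import GIN
    from causallearn.utils.cit import kci

    data = np.random.uniform(-1, 1, (500, 6)) ** 5   # toy non-Gaussian sample
    g, k = GIN(data, indep_test=kci, alpha=0.05)
    print(g, k)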
