From DMIRLab: https://dmir.gdut.edu.cn/
'''
77
8+ import random
89from collections import deque
910from itertools import combinations
1011
1112import numpy as np
13+ from scipy .stats import chi2
1214
1315from causallearn .graph .GeneralGraph import GeneralGraph
1416from causallearn .graph .GraphNode import GraphNode
1517from causallearn .graph .NodeType import NodeType
18+ from causallearn .graph .Edge import Edge
19+ from causallearn .graph .Endpoint import Endpoint
1620from causallearn .search .FCMBased .lingam .hsic import hsic_test_gamma
21+ from causallearn .utils .cit import kci
1722
1823
19- def GIN (data ):
def fisher_test(pvals):
    """Aggregate a list of p-values into a single decision statistic.

    Each p-value is floored at 1e-5 and the minimum of the floored values
    is returned.  NOTE: despite the name, this is NOT Fisher's chi-square
    combination (-2 * sum(log p) against chi2 with 2k dof); the min rule is
    what the surrounding algorithm actually uses.
    """
    # min over the clamped values == max(1e-5, min(pvals)) for ordinary floats
    return min(p if p >= 1e-5 else 1e-5 for p in pvals)
29+
30+
def GIN(data, indep_test=kci, alpha=0.05):
    '''
    Learning causal structure of Latent Variables for Linear Non-Gaussian Latent Variable Model
    with Generalized Independent Noise Condition

    Parameters
    ----------
    data : numpy ndarray
        data set
    indep_test : callable, default=kci
        the function of the independence test being used
    alpha : float, default=0.05
        desired significance level of independence tests (p_value) in (0,1)
    Returns
    -------
    G : general graph
        causal graph
    K : list
        causal order
    '''
    n = data.shape[1]
    cov = np.cov(data.T)

    # ---- Phase 1: cluster observed variables that share a latent parent. ----
    # Try every candidate cluster of growing size; a cluster passes when the
    # GIN surrogate noise e = omega^T X is independent of every remaining variable.
    var_set = set(range(n))
    cluster_size = 2
    clusters_list = []
    while cluster_size < len(var_set):
        tmp_clusters_list = []
        for cluster in combinations(var_set, cluster_size):
            remain_var_set = var_set - set(cluster)
            # Surrogate noise for (cluster | remaining variables).
            e = cal_e_with_gin(data, cov, list(cluster), list(remain_var_set))
            pvals = []
            # Append e as the last column so the test can reference it via index -1.
            tmp_data = np.concatenate([data[:, list(remain_var_set)], e.reshape(-1, 1)], axis=1)
            for z in range(len(remain_var_set)):
                pvals.append(indep_test(tmp_data, z, -1))
            # Aggregate per-variable p-values; accept cluster when independence holds.
            fisher_pval = fisher_test(pvals)
            if fisher_pval >= alpha:
                tmp_clusters_list.append(cluster)
        # Clusters found at this size may overlap; merge them before committing.
        tmp_clusters_list = merge_overlaping_cluster(tmp_clusters_list)
        clusters_list = clusters_list + tmp_clusters_list
        # Remove clustered variables so later (larger) clusters only use the rest.
        for cluster in tmp_clusters_list:
            var_set -= set(cluster)
        cluster_size += 1

    # ---- Phase 2: order the clusters causally (find roots one at a time). ----
    K = []
    updated = True
    while updated:
        updated = False
        # Split every already-ordered cluster in half: one half joins the X side,
        # the other the conditioning side Z of the GIN test below.
        X = []
        Z = []
        for cluster_k in K:
            cluster_k1, cluster_k2 = array_split(cluster_k, 2)
            X += cluster_k1
            Z += cluster_k2

        for i, cluster_i in enumerate(clusters_list):
            is_root = True
            # Shuffle so the half-split is a random partition of the cluster.
            random.shuffle(cluster_i)
            cluster_i1, cluster_i2 = array_split(cluster_i, 2)
            for j, cluster_j in enumerate(clusters_list):
                if i == j:
                    continue
                random.shuffle(cluster_j)
                cluster_j1, cluster_j2 = array_split(cluster_j, 2)
                # cluster_i is a root candidate: the GIN condition must hold for
                # (ordered halves + i's half + j's half) vs (Z + i's other half).
                e = cal_e_with_gin(data, cov, X + cluster_i1 + cluster_j1, Z + cluster_i2)
                pvals = []
                tmp_data = np.concatenate([data[:, Z + cluster_i2], e.reshape(-1, 1)], axis=1)
                for z in range(len(Z + cluster_i2)):
                    pvals.append(indep_test(tmp_data, z, -1))
                fisher_pval = fisher_test(pvals)
                if fisher_pval < alpha:
                    # Dependence detected: cluster_i cannot be the next root.
                    is_root = False
                    break
            if is_root:
                # Move the found root from the unordered pool into the order K,
                # then restart the scan (clusters_list was mutated).
                K.append(cluster_i)
                clusters_list.remove(cluster_i)
                updated = True
                break

    # ---- Phase 3: build the output graph. ----
    G = GeneralGraph([])
    # Observed variables that never joined any cluster.
    for var in var_set:
        o_node = GraphNode(f"X{var + 1}")
        G.add_node(o_node)

    latent_id = 1
    l_nodes = []

    # Ordered clusters: each gets a latent parent; earlier latents point to later ones.
    for cluster in K:
        l_node = GraphNode(f"L{latent_id}")
        l_node.set_node_type(NodeType.LATENT)
        G.add_node(l_node)
        for l in l_nodes:
            G.add_directed_edge(l, l_node)
        l_nodes.append(l_node)

        for o in cluster:
            o_node = GraphNode(f"X{o + 1}")
            G.add_node(o_node)
            G.add_directed_edge(l_node, o_node)
        latent_id += 1

    undirected_l_nodes = []

    # Remaining (unordered) clusters: latents receive edges from all ordered
    # latents, but stay mutually undirected (tail-tail edges) among themselves.
    for cluster in clusters_list:
        l_node = GraphNode(f"L{latent_id}")
        l_node.set_node_type(NodeType.LATENT)
        G.add_node(l_node)
        for l in l_nodes:
            G.add_directed_edge(l, l_node)

        for l in undirected_l_nodes:
            G.add_edge(Edge(l, l_node, Endpoint.TAIL, Endpoint.TAIL))

        undirected_l_nodes.append(l_node)

        for o in cluster:
            o_node = GraphNode(f"X{o + 1}")
            G.add_node(o_node)
            G.add_directed_edge(l_node, o_node)
        latent_id += 1

    return G, K
153+
154+
155+ def GIN_MI (data ):
20156 '''
21157 Learning causal structure of Latent Variables for Linear Non-Gaussian Latent Variable Model
22158 with Generalized Independent Noise Condition
@@ -81,6 +217,13 @@ def GIN(data):
81217 return G , K
82218
83219
def cal_e_with_gin(data, cov, X, Z):
    """Compute the GIN surrogate noise series e = omega^T X for every sample.

    omega is the right singular vector of the cross-covariance cov(Z, X)
    associated with its smallest singular value, so omega is (approximately)
    annihilated by cov(Z, X); the result has one value per row of ``data``.
    """
    cross_cov = cov[np.ix_(Z, X)]
    _, _, vh = np.linalg.svd(cross_cov)
    # Singular values come back in descending order, so the last row of vh
    # is the right singular vector for the smallest singular value.
    omega = vh[-1]
    return data[:, X] @ omega
225+
226+
84227def cal_dep_for_gin (data , cov , X , Z ):
85228 '''
86229 Calculate the statistics of dependence via Generalized Independent Noise Condition
@@ -96,10 +239,8 @@ def cal_dep_for_gin(data, cov, X, Z):
96239 -------
97240 sta : test statistic
98241 '''
99- cov_m = cov [np .ix_ (Z , X )]
100- _ , _ , v = np .linalg .svd (cov_m )
101- omega = v .T [:, - 1 ]
102- e_xz = np .dot (omega , data [:, X ].T )
242+
243+ e_xz = cal_e_with_gin (data , cov , X , Z )
103244
104245 sta = 0
105246 for i in Z :
@@ -160,6 +301,8 @@ def _get_all_elements(S):
160301# merging cluster
161302def merge_overlaping_cluster (cluster_list ):
162303 v_labels = _get_all_elements (cluster_list )
304+ if len (v_labels ) == 0 :
305+ return []
163306 cluster_dict = {i : - 1 for i in v_labels }
164307 cluster_b = {i : [] for i in v_labels }
165308 cluster_len = 0
@@ -197,3 +340,20 @@ def merge_overlaping_cluster(cluster_list):
197340 cluster [cluster_dict [i ]].append (i )
198341
199342 return cluster
343+
344+
def array_split(x, k):
    """Split sequence ``x`` into ``k`` contiguous chunks (list analogue of
    numpy.array_split).

    The first ``len(x) % k`` chunks get one extra element; when ``k``
    exceeds ``len(x)`` the trailing chunks are empty.
    """
    base_len, extra = divmod(len(x), k)
    chunks = []
    offset = 0
    for idx in range(k):
        # The first `extra` chunks absorb the remainder, one element each.
        size = base_len + 1 if idx < extra else base_len
        chunks.append(x[offset:offset + size])
        offset += size
    return chunks
0 commit comments