Skip to content

Commit 24687d9

Browse files
authored
Merge pull request #46 from MarkDana/cit_as_object
Rewrite CITests as a class && re-use covariance matrix for fisherz
2 parents 104b5dd + d9e87b0 commit 24687d9

File tree

10 files changed

+548
-939
lines changed

10 files changed

+548
-939
lines changed

causallearn/graph/GraphClass.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from causallearn.graph.Node import Node
1919
from causallearn.utils.GraphUtils import GraphUtils
2020
from causallearn.utils.PCUtils.Helper import list_union, powerset
21+
from causallearn.utils.cit import CIT
2122

2223

2324
class CausalGraph:
@@ -34,10 +35,7 @@ def __init__(self, no_of_var: int, node_names: List[str] | None = None):
3435
for i in range(no_of_var):
3536
for j in range(i + 1, no_of_var):
3637
self.G.add_edge(Edge(nodes[i], nodes[j], Endpoint.TAIL, Endpoint.TAIL))
37-
38-
self.data = None # store the data
39-
self.test = None # store the name of the conditional independence test
40-
self.corr_mat = None # store the correlation matrix of the data
38+
self.test: CIT | None = None
4139
self.sepset = np.empty((no_of_var, no_of_var), object) # store the collection of sepsets
4240
self.definite_UC = [] # store the list of definite unshielded colliders
4341
self.definite_non_UC = [] # store the list of definite unshielded non-colliders
@@ -47,35 +45,17 @@ def __init__(self, no_of_var: int, node_names: List[str] | None = None):
4745
self.nx_skel = nx.Graph() # store the undirected graph
4846
self.labels = {}
4947
self.prt_m = {} # store the parents of missingness indicators
50-
self.mvpc = False
51-
self.cardinalities = None # only works when self.data is discrete, i.e. self.test is chisq or gsq
52-
self.is_discrete = False
53-
self.citest_cache = dict()
54-
self.data_hash_key = None
55-
self.ci_test_hash_key = None
56-
57-
def set_ind_test(self, indep_test, mvpc=False):
48+
49+
50+
def set_ind_test(self, indep_test):
5851
"""Set the conditional independence test that will be used"""
59-
# assert name_of_test in ["Fisher_Z", "Chi_sq", "G_sq"]
60-
self.mvpc = mvpc
6152
self.test = indep_test
62-
self.ci_test_hash_key = hash(indep_test)
6353

6454
def ci_test(self, i: int, j: int, S) -> float:
6555
"""Define the conditional independence test"""
6656
# assert i != j and not i in S and not j in S
67-
if self.mvpc:
68-
return self.test(self.data, self.nx_skel, self.prt_m, i, j, S)
69-
70-
i, j = (i, j) if (i < j) else (j, i)
71-
ijS_key = (i, j, frozenset(S), self.data_hash_key, self.ci_test_hash_key)
72-
if ijS_key in self.citest_cache:
73-
return self.citest_cache[ijS_key]
74-
# if discrete, assert self.test is chisq or gsq, pass into cardinalities
75-
pValue = self.test(self.data, i, j, S, self.cardinalities) if self.is_discrete \
76-
else self.test(self.data, i, j, S)
77-
self.citest_cache[ijS_key] = pValue
78-
return pValue
57+
if self.test.method == 'mc_fisherz': return self.test(i, j, S, self.nx_skel, self.prt_m)
58+
return self.test(i, j, S)
7959

8060
def neighbors(self, i: int):
8161
"""Find the neighbors of node i in adjmat"""

causallearn/search/ConstraintBased/CDNOD.py

Lines changed: 9 additions & 261 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
1111
from causallearn.utils.PCUtils.BackgroundKnowledgeOrientUtils import orient_by_background_knowledge
1212
from causallearn.utils.cit import *
13+
from causallearn.search.ConstraintBased.PC import get_parent_missingness_pairs, skeleton_correction
1314

1415

15-
def cdnod(data: ndarray, c_indx: ndarray, alpha: float = 0.05, indep_test=fisherz, stable: bool = True,
16-
uc_rule: int = 0, uc_priority: int = 2, mvcdnod: bool = False, correction_name: str = 'MV_Crtn_Fisher_Z',
17-
background_knowledge: Optional[BackgroundKnowledge] = None, verbose: bool = False,
16+
def cdnod(data: ndarray, c_indx: ndarray, alpha: float=0.05, indep_test: str=fisherz, stable: bool=True,
17+
uc_rule: int=0, uc_priority: int=2, mvcdnod: bool=False, correction_name: str='MV_Crtn_Fisher_Z',
18+
background_knowledge: Optional[BackgroundKnowledge]=None, verbose: bool=False,
1819
show_progress: bool = True) -> CausalGraph:
1920
"""
2021
Causal discovery from nonstationary/heterogeneous data
@@ -43,7 +44,7 @@ def cdnod(data: ndarray, c_indx: ndarray, alpha: float = 0.05, indep_test=fisher
4344
show_progress=show_progress)
4445

4546

46-
def cdnod_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: int, uc_priority: int,
47+
def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rule: int, uc_priority: int,
4748
background_knowledge: Optional[BackgroundKnowledge] = None, verbose: bool = False,
4849
show_progress: bool = True) -> CausalGraph:
4950
"""
@@ -84,6 +85,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: in
8485
8586
"""
8687
start = time.time()
88+
indep_test = CIT(data, indep_test)
8789
cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable)
8890

8991
# orient the direction from c_indx to X, if there is an edge between c_indx and X
@@ -124,7 +126,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: in
124126
return cg
125127

126128

127-
def mvcdnod_alg(data: ndarray, alpha: float, indep_test, correction_name: str, stable: bool, uc_rule: int,
129+
def mvcdnod_alg(data: ndarray, alpha: float, indep_test: str, correction_name: str, stable: bool, uc_rule: int,
128130
uc_priority: int, verbose: bool, show_progress: bool) -> CausalGraph:
129131
"""
130132
:param data: data set (numpy ndarray)
@@ -154,9 +156,9 @@ def mvcdnod_alg(data: ndarray, alpha: float, indep_test, correction_name: str, s
154156
"""
155157

156158
start = time.time()
157-
159+
indep_test = CIT(data, indep_test)
158160
## Step 1: detect the direct causes of missingness indicators
159-
prt_m = get_prt_mpairs(data, alpha, indep_test, stable)
161+
prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable)
160162
# print('Finish detecting the parents of missingness indicators. ')
161163

162164
## Step 2:
@@ -204,257 +206,3 @@ def mvcdnod_alg(data: ndarray, alpha: float, indep_test, correction_name: str, s
204206
cg.PC_elapsed = end - start
205207

206208
return cg
207-
208-
209-
#######################################################################################################################
210-
## *********** Functions for Step 1 ***********
211-
def get_prt_mpairs(data: ndarray, alpha: float, indep_test, stable: bool = True) -> Dict[str, list]:
212-
"""
213-
Detect the parents of missingness indicators
214-
If a missingness indicator has no parent, it will not be included in the result
215-
:param data: data set (numpy ndarray)
216-
:param alpha: desired significance level in (0, 1) (float)
217-
:param indep_test: name of the test-wise deletion independence test being used
218-
- "MV_Fisher_Z": Fisher's Z conditional independence test
219-
- "MV_G_sq": G-squared conditional independence test (TODO: under development)
220-
:param stable: run stabilized skeleton discovery if True (default = True)
221-
:return:
222-
prt_m: dict with keys 'prt' (list of parent arrays) and 'm' (list of the corresponding missingness indicator indices)
223-
"""
224-
prt_m = {'prt': [], 'm': []}
225-
226-
## Get the index of missingness indicators
227-
m_indx = get_mindx(data)
228-
229-
## Get the index of parents of missingness indicators
230-
# If the missingness indicator has no parent, then it will not be collected in prt_m
231-
for r in m_indx:
232-
prt_r = detect_parent(r, data, alpha, indep_test, stable)
233-
if isempty(prt_r):
234-
pass
235-
else:
236-
prt_m['prt'].append(prt_r)
237-
prt_m['m'].append(r)
238-
return prt_m
239-
240-
241-
def isempty(prt_r: ndarray) -> bool:
242-
"""Test whether the parent of a missingness indicator is empty"""
243-
return len(prt_r) == 0
244-
245-
246-
def get_mindx(data: ndarray) -> List[int]:
247-
"""Detect the indices of missingness indicators (columns that contain NaN)
248-
:param data: data set (numpy ndarray)
249-
:return:
250-
m_indx: list, the index of missingness indicators
251-
"""
252-
253-
m_indx = []
254-
_, ncol = np.shape(data)
255-
for i in range(ncol):
256-
if np.isnan(data[:, i]).any():
257-
m_indx.append(i)
258-
return m_indx
259-
260-
261-
def detect_parent(r: int, data_: ndarray, alpha: float, indep_test, stable: bool = True) -> ndarray:
262-
"""Detect the parents of a missingness indicator
263-
:param r: the missingness indicator
264-
:param data_: data set (numpy ndarray)
265-
:param alpha: desired significance level in (0, 1) (float)
266-
:param indep_test: name of the test-wise deletion independence test being used
267-
- "MV_Fisher_Z": Fisher's Z conditional independence test
268-
- "MV_G_sq": G-squared conditional independence test (TODO: under development)
269-
:param stable: run stabilized skeleton discovery if True (default = True)
270-
:return:
271-
prt: parent of the missingness indicator, r
272-
"""
273-
## TODO: in the test-wise deletion CI test, if test between a binary and a continuous variable,
274-
# there can be the case where the binary variable only take one value after deletion.
275-
# It is because the assumption is violated.
276-
277-
## *********** Adaptation 0 ***********
278-
# To avoid changing the original data
279-
data = data_.copy()
280-
## *********** End ***********
281-
282-
assert type(data) == np.ndarray
283-
assert 0 < alpha < 1
284-
285-
## *********** Adaptation 1 ***********
286-
# data
287-
## Replace the variable r with its missingness indicator
288-
## If r is not a missingness indicator, return [].
289-
data[:, r] = np.isnan(data[:, r]).astype(float) # True is missing; false is not missing
290-
if sum(data[:, r]) == 0 or sum(data[:, r]) == len(data[:, r]):
291-
return np.empty(0)
292-
## *********** End ***********
293-
294-
no_of_var = data.shape[1]
295-
cg = CausalGraph(no_of_var)
296-
cg.data = data
297-
cg.set_ind_test(indep_test)
298-
cg.corr_mat = np.corrcoef(data, rowvar=False) if indep_test == fisherz else []
299-
300-
node_ids = range(no_of_var)
301-
pair_of_variables = list(permutations(node_ids, 2))
302-
303-
depth = -1
304-
while cg.max_degree() - 1 > depth:
305-
depth += 1
306-
edge_removal = []
307-
for (x, y) in pair_of_variables:
308-
309-
## *********** Adaptation 2 ***********
310-
# the skeleton search
311-
## Only test which variable is the neighbor of r
312-
if x != r:
313-
continue
314-
## *********** End ***********
315-
316-
Neigh_x = cg.neighbors(x)
317-
if y not in Neigh_x:
318-
continue
319-
else:
320-
Neigh_x = np.delete(Neigh_x, np.where(Neigh_x == y))
321-
322-
if len(Neigh_x) >= depth:
323-
for S in combinations(Neigh_x, depth):
324-
p = cg.ci_test(x, y, S)
325-
if p > alpha:
326-
if not stable: # Unstable: Remove x---y right away
327-
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
328-
if edge1 is not None:
329-
cg.G.remove_edge(edge1)
330-
edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
331-
if edge2 is not None:
332-
cg.G.remove_edge(edge2)
333-
else: # Stable: x---y will be removed only
334-
edge_removal.append((x, y)) # after all conditioning sets at
335-
edge_removal.append((y, x)) # depth l have been considered
336-
Helper.append_value(cg.sepset, x, y, S)
337-
Helper.append_value(cg.sepset, y, x, S)
338-
break
339-
340-
for (x, y) in list(set(edge_removal)):
341-
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
342-
if edge1 is not None:
343-
cg.G.remove_edge(edge1)
344-
345-
## *********** Adaptation 3 ***********
346-
## extract the parent of r from the graph
347-
cg.to_nx_skeleton()
348-
cg_skel_adj: ndarray = nx.to_numpy_array(cg.nx_skel).astype(int)
349-
prt = get_parent(r, cg_skel_adj)
350-
## *********** End ***********
351-
352-
return prt
353-
354-
355-
def get_parent(r: int, cg_skel_adj: ndarray) -> ndarray:
356-
"""Get the neighbors of missingness indicators which are the parents
357-
:param r: the missingness indicator index
358-
:param cg_skel_adj: adjacency matrix of a causal skeleton
359-
:return:
360-
prt: list, parents of the missingness indicator r
361-
"""
362-
num_var = len(cg_skel_adj[0, :])
363-
indx = np.array([i for i in range(num_var)])
364-
prt = indx[cg_skel_adj[r, :] == 1]
365-
return prt
366-
367-
368-
## *********** END ***********
369-
#######################################################################################################################
370-
371-
def skeleton_correction(data: ndarray, alpha: float, test_with_correction_name: str,
372-
init_cg: CausalGraph, prt_m: dict, stable: bool = True) -> CausalGraph:
373-
"""Perform skeleton discovery
374-
:param data: data set (numpy ndarray)
375-
:param alpha: desired significance level in (0, 1) (float)
376-
:param test_with_correction_name: name of the independence test being used
377-
- "MV_Crtn_Fisher_Z": Fisher's Z conditional independence test
378-
- "MV_Crtn_G_sq": G-squared conditional independence test
379-
:param stable: run stabilized skeleton discovery if True (default = True)
380-
:return:
381-
cg: a CausalGraph object
382-
"""
383-
384-
assert type(data) == np.ndarray
385-
assert 0 < alpha < 1
386-
assert test_with_correction_name in ["MV_Crtn_Fisher_Z", "MV_Crtn_G_sq"]
387-
388-
## *********** Adaptation 1 ***********
389-
no_of_var = data.shape[1]
390-
391-
## Initialize the graph with the result of test-wise deletion skeleton search
392-
cg = init_cg
393-
394-
cg.data = data
395-
if test_with_correction_name in ["MV_Crtn_Fisher_Z", "MV_Crtn_G_sq"]:
396-
cg.set_ind_test(mc_fisherz, True)
397-
# No need of the correlation matrix if using test-wise deletion test
398-
cg.corr_mat = np.corrcoef(data, rowvar=False) if test_with_correction_name == "MV_Crtn_Fisher_Z" else []
399-
cg.prt_m = prt_m
400-
## *********** End ***********
401-
402-
node_ids = range(no_of_var)
403-
pair_of_variables = list(permutations(node_ids, 2))
404-
405-
depth = -1
406-
while cg.max_degree() - 1 > depth:
407-
depth += 1
408-
edge_removal = []
409-
for (x, y) in pair_of_variables:
410-
Neigh_x = cg.neighbors(x)
411-
if y not in Neigh_x:
412-
continue
413-
else:
414-
Neigh_x = np.delete(Neigh_x, np.where(Neigh_x == y))
415-
416-
if len(Neigh_x) >= depth:
417-
for S in combinations(Neigh_x, depth):
418-
p = cg.ci_test(x, y, S)
419-
if p > alpha:
420-
if not stable: # Unstable: Remove x---y right away
421-
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
422-
if edge1 is not None:
423-
cg.G.remove_edge(edge1)
424-
edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
425-
if edge2 is not None:
426-
cg.G.remove_edge(edge2)
427-
else: # Stable: x---y will be removed only
428-
edge_removal.append((x, y)) # after all conditioning sets at
429-
edge_removal.append((y, x)) # depth l have been considered
430-
Helper.append_value(cg.sepset, x, y, S)
431-
Helper.append_value(cg.sepset, y, x, S)
432-
break
433-
434-
for (x, y) in list(set(edge_removal)):
435-
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
436-
if edge1 is not None:
437-
cg.G.remove_edge(edge1)
438-
439-
return cg
440-
441-
442-
#######################################################################################################################
443-
444-
# *********** Evaluation util ***********
445-
446-
def get_adjacancy_matrix(g: CausalGraph):
447-
return nx.to_numpy_array(g.nx_graph).astype(int)
448-
449-
450-
def matrix_diff(cg1: CausalGraph, cg2: CausalGraph):
451-
adj1 = get_adjacancy_matrix(cg1)
452-
adj2 = get_adjacancy_matrix(cg2)
453-
count = 0
454-
diff_ls = []
455-
for i in range(len(adj1[:, ])):
456-
for j in range(len(adj2[:, ])):
457-
if adj1[i, j] != adj2[i, j]:
458-
diff_ls.append((i, j))
459-
count += 1
460-
return count / 2, diff_ls

0 commit comments

Comments
 (0)