|
10 | 10 | from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge |
11 | 11 | from causallearn.utils.PCUtils.BackgroundKnowledgeOrientUtils import orient_by_background_knowledge |
12 | 12 | from causallearn.utils.cit import * |
| 13 | +from causallearn.search.ConstraintBased.PC import get_parent_missingness_pairs, skeleton_correction |
13 | 14 |
|
14 | 15 |
|
15 | | -def cdnod(data: ndarray, c_indx: ndarray, alpha: float = 0.05, indep_test=fisherz, stable: bool = True, |
16 | | - uc_rule: int = 0, uc_priority: int = 2, mvcdnod: bool = False, correction_name: str = 'MV_Crtn_Fisher_Z', |
17 | | - background_knowledge: Optional[BackgroundKnowledge] = None, verbose: bool = False, |
| 16 | +def cdnod(data: ndarray, c_indx: ndarray, alpha: float=0.05, indep_test: str=fisherz, stable: bool=True, |
| 17 | + uc_rule: int=0, uc_priority: int=2, mvcdnod: bool=False, correction_name: str='MV_Crtn_Fisher_Z', |
| 18 | + background_knowledge: Optional[BackgroundKnowledge]=None, verbose: bool=False, |
18 | 19 | show_progress: bool = True) -> CausalGraph: |
19 | 20 | """ |
20 | 21 | Causal discovery from nonstationary/heterogeneous data |
@@ -43,7 +44,7 @@ def cdnod(data: ndarray, c_indx: ndarray, alpha: float = 0.05, indep_test=fisher |
43 | 44 | show_progress=show_progress) |
44 | 45 |
|
45 | 46 |
|
46 | | -def cdnod_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: int, uc_priority: int, |
| 47 | +def cdnod_alg(data: ndarray, alpha: float, indep_test: str, stable: bool, uc_rule: int, uc_priority: int, |
47 | 48 | background_knowledge: Optional[BackgroundKnowledge] = None, verbose: bool = False, |
48 | 49 | show_progress: bool = True) -> CausalGraph: |
49 | 50 | """ |
@@ -84,6 +85,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: in |
84 | 85 |
|
85 | 86 | """ |
86 | 87 | start = time.time() |
| 88 | + indep_test = CIT(data, indep_test) |
87 | 89 | cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable) |
88 | 90 |
|
89 | 91 | # orient the direction from c_indx to X, if there is an edge between c_indx and X |
@@ -124,7 +126,7 @@ def cdnod_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: in |
124 | 126 | return cg |
125 | 127 |
|
126 | 128 |
|
127 | | -def mvcdnod_alg(data: ndarray, alpha: float, indep_test, correction_name: str, stable: bool, uc_rule: int, |
| 129 | +def mvcdnod_alg(data: ndarray, alpha: float, indep_test: str, correction_name: str, stable: bool, uc_rule: int, |
128 | 130 | uc_priority: int, verbose: bool, show_progress: bool) -> CausalGraph: |
129 | 131 | """ |
130 | 132 | :param data: data set (numpy ndarray) |
@@ -154,9 +156,9 @@ def mvcdnod_alg(data: ndarray, alpha: float, indep_test, correction_name: str, s |
154 | 156 | """ |
155 | 157 |
|
156 | 158 | start = time.time() |
157 | | - |
| 159 | + indep_test = CIT(data, indep_test) |
158 | 160 | ## Step 1: detect the direct causes of missingness indicators |
159 | | - prt_m = get_prt_mpairs(data, alpha, indep_test, stable) |
| 161 | + prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable) |
160 | 162 | # print('Finish detecting the parents of missingness indicators. ') |
161 | 163 |
|
162 | 164 | ## Step 2: |
@@ -204,257 +206,3 @@ def mvcdnod_alg(data: ndarray, alpha: float, indep_test, correction_name: str, s |
204 | 206 | cg.PC_elapsed = end - start |
205 | 207 |
|
206 | 208 | return cg |
207 | | - |
208 | | - |
209 | | -####################################################################################################################### |
210 | | -## *********** Functions for Step 1 *********** |
def get_prt_mpairs(data: ndarray, alpha: float, indep_test, stable: bool = True) -> Dict[str, list]:
    """
    Detect the parents of missingness indicators.

    Missingness indicators for which no parent is detected are omitted
    from the result.
    :param data: data set (numpy ndarray)
    :param alpha: desired significance level in (0, 1) (float)
    :param indep_test: name of the test-wise deletion independence test being used
        - "MV_Fisher_Z": Fisher's Z conditional independence test
        - "MV_G_sq": G-squared conditional independence test (TODO: under development)
    :param stable: run stabilized skeleton discovery if True (default = True)
    :return:
        prt_m: dict with keys 'prt' (arrays of parent indices) and 'm'
               (indices of the corresponding missingness indicators)
    """
    prt_m = {'prt': [], 'm': []}

    ## Get the index of missingness indicators
    m_indx = get_mindx(data)

    ## Get the index of parents of missingness indicators
    # If a missingness indicator has no parent, it is not collected in prt_m
    for r in m_indx:
        prt_r = detect_parent(r, data, alpha, indep_test, stable)
        if not isempty(prt_r):
            prt_m['prt'].append(prt_r)
            prt_m['m'].append(r)
    return prt_m
239 | | - |
240 | | - |
def isempty(prt_r: ndarray) -> bool:
    """Return True when a missingness indicator's parent set is empty."""
    return not len(prt_r)
244 | | - |
245 | | - |
def get_mindx(data: ndarray) -> List[int]:
    """Find the indices of missingness indicators.

    A column is a missingness indicator when it contains at least one NaN.
    :param data: data set (numpy ndarray)
    :return:
        m_indx: list, the column indices of missingness indicators
    """
    _, ncol = np.shape(data)
    # A column carries missingness information iff it has any NaN entry.
    return [i for i in range(ncol) if np.isnan(data[:, i]).any()]
259 | | - |
260 | | - |
def detect_parent(r: int, data_: ndarray, alpha: float, indep_test, stable: bool = True) -> ndarray:
    """Detect the parents of a missingness indicator.

    Runs a PC-style skeleton search restricted to edges incident to r
    (the data column r is first replaced by its 0/1 missingness mask),
    then reads r's neighbors off the resulting skeleton.
    :param r: index of the missingness indicator column
    :param data_: data set (numpy ndarray); not modified (a copy is taken)
    :param alpha: desired significance level in (0, 1) (float)
    :param indep_test: name of the test-wise deletion independence test being used
        - "MV_Fisher_Z": Fisher's Z conditional independence test
        - "MV_G_sq": G-squared conditional independence test (TODO: under development)
    :param stable: run stabilized skeleton discovery if True (default = True)
    :return:
        prt: ndarray of parent indices of the missingness indicator r
             (empty array when column r has no missing values, or all missing)
    """
    ## TODO: in the test-wise deletion CI test, if test between a binary and a continuous variable,
    # there can be the case where the binary variable only take one value after deletion.
    # It is because the assumption is violated.

    ## *********** Adaptation 0 ***********
    # Work on a copy to avoid changing the original data.
    data = data_.copy()
    ## *********** End ***********

    assert type(data) == np.ndarray
    assert 0 < alpha < 1

    ## *********** Adaptation 1 ***********
    # data
    ## Replace the variable r with its missingness indicator
    ## If r is not a missingness indicator, return [].
    data[:, r] = np.isnan(data[:, r]).astype(float)  # True is missing; false is not missing
    # All-zero or all-one mask means r is constant after masking: nothing to detect.
    if sum(data[:, r]) == 0 or sum(data[:, r]) == len(data[:, r]):
        return np.empty(0)
    ## *********** End ***********

    no_of_var = data.shape[1]
    cg = CausalGraph(no_of_var)
    cg.data = data
    cg.set_ind_test(indep_test)
    # Correlation matrix is only needed by the Fisher-Z test.
    cg.corr_mat = np.corrcoef(data, rowvar=False) if indep_test == fisherz else []

    node_ids = range(no_of_var)
    pair_of_variables = list(permutations(node_ids, 2))

    # Standard PC skeleton loop: grow the conditioning-set size (depth)
    # until it exceeds the maximum node degree minus one.
    depth = -1
    while cg.max_degree() - 1 > depth:
        depth += 1
        edge_removal = []
        for (x, y) in pair_of_variables:

            ## *********** Adaptation 2 ***********
            # the skeleton search
            ## Only test edges incident to r (x must be r).
            if x != r:
                continue
            ## *********** End ***********

            Neigh_x = cg.neighbors(x)
            if y not in Neigh_x:
                continue
            else:
                Neigh_x = np.delete(Neigh_x, np.where(Neigh_x == y))

            if len(Neigh_x) >= depth:
                for S in combinations(Neigh_x, depth):
                    p = cg.ci_test(x, y, S)
                    if p > alpha:
                        if not stable:  # Unstable: Remove x---y right away
                            edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
                            if edge1 is not None:
                                cg.G.remove_edge(edge1)
                            edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
                            if edge2 is not None:
                                cg.G.remove_edge(edge2)
                        else:  # Stable: x---y will be removed only
                            edge_removal.append((x, y))  # after all conditioning sets at
                            edge_removal.append((y, x))  # depth l have been considered
                            Helper.append_value(cg.sepset, x, y, S)
                            Helper.append_value(cg.sepset, y, x, S)
                        break

        # Stable variant: apply all removals collected at this depth at once.
        for (x, y) in list(set(edge_removal)):
            edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
            if edge1 is not None:
                cg.G.remove_edge(edge1)

    ## *********** Adaptation 3 ***********
    ## extract the parent of r from the graph
    cg.to_nx_skeleton()
    cg_skel_adj: ndarray = nx.to_numpy_array(cg.nx_skel).astype(int)
    prt = get_parent(r, cg_skel_adj)
    ## *********** End ***********

    return prt
353 | | - |
354 | | - |
def get_parent(r: int, cg_skel_adj: ndarray) -> ndarray:
    """Get the neighbors of a missingness indicator, which are its parents.

    :param r: the missingness indicator index
    :param cg_skel_adj: adjacency matrix of a causal skeleton (0/1 entries)
    :return:
        prt: ndarray, indices of the parents of the missingness indicator r
    """
    # Row r of the skeleton adjacency matrix marks r's neighbors with 1s;
    # np.flatnonzero replaces the manual index-array construction.
    return np.flatnonzero(cg_skel_adj[r, :] == 1)
366 | | - |
367 | | - |
368 | | -## *********** END *********** |
369 | | -####################################################################################################################### |
370 | | - |
def skeleton_correction(data: ndarray, alpha: float, test_with_correction_name: str,
                        init_cg: CausalGraph, prt_m: dict, stable: bool = True) -> CausalGraph:
    """Perform skeleton discovery with a correction-based missing-value test.

    Re-runs the PC skeleton search starting from an existing skeleton,
    using the mc_fisherz correction test that consumes the parents of
    missingness indicators in prt_m.
    :param data: data set (numpy ndarray)
    :param alpha: desired significance level in (0, 1) (float)
    :param test_with_correction_name: name of the independence test being used
        - "MV_Crtn_Fisher_Z": Fisher's Z conditional independence test
        - "MV_Crtn_G_sq": G-squared conditional independence test
    :param init_cg: initial CausalGraph (result of the test-wise deletion
        skeleton search); modified in place and returned
    :param prt_m: dict with keys 'prt' and 'm' mapping missingness
        indicators to their detected parents
    :param stable: run stabilized skeleton discovery if True (default = True)
    :return:
        cg: a CausalGraph object (the corrected init_cg)
    """

    assert type(data) == np.ndarray
    assert 0 < alpha < 1
    assert test_with_correction_name in ["MV_Crtn_Fisher_Z", "MV_Crtn_G_sq"]

    ## *********** Adaption 1 ***********
    no_of_var = data.shape[1]

    ## Initialize the graph with the result of test-wise deletion skeleton search
    cg = init_cg

    cg.data = data
    if test_with_correction_name in ["MV_Crtn_Fisher_Z", "MV_Crtn_G_sq"]:
        cg.set_ind_test(mc_fisherz, True)
    # No need of the correlation matrix if using test-wise deletion test
    cg.corr_mat = np.corrcoef(data, rowvar=False) if test_with_correction_name == "MV_Crtn_Fisher_Z" else []
    cg.prt_m = prt_m
    ## *********** Adaption 1 ***********

    node_ids = range(no_of_var)
    pair_of_variables = list(permutations(node_ids, 2))

    # Standard PC skeleton loop: grow the conditioning-set size (depth)
    # until it exceeds the maximum node degree minus one.
    depth = -1
    while cg.max_degree() - 1 > depth:
        depth += 1
        edge_removal = []
        for (x, y) in pair_of_variables:
            Neigh_x = cg.neighbors(x)
            if y not in Neigh_x:
                continue
            else:
                Neigh_x = np.delete(Neigh_x, np.where(Neigh_x == y))

            if len(Neigh_x) >= depth:
                for S in combinations(Neigh_x, depth):
                    p = cg.ci_test(x, y, S)
                    if p > alpha:
                        if not stable:  # Unstable: Remove x---y right away
                            edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
                            if edge1 is not None:
                                cg.G.remove_edge(edge1)
                            edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
                            if edge2 is not None:
                                cg.G.remove_edge(edge2)
                        else:  # Stable: x---y will be removed only
                            edge_removal.append((x, y))  # after all conditioning sets at
                            edge_removal.append((y, x))  # depth l have been considered
                            Helper.append_value(cg.sepset, x, y, S)
                            Helper.append_value(cg.sepset, y, x, S)
                        break

        # Stable variant: apply all removals collected at this depth at once.
        for (x, y) in list(set(edge_removal)):
            edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
            if edge1 is not None:
                cg.G.remove_edge(edge1)

    return cg
440 | | - |
441 | | - |
442 | | -####################################################################################################################### |
443 | | - |
444 | | -# *********** Evaluation util *********** |
445 | | - |
def get_adjacancy_matrix(g: CausalGraph):
    """Return the skeleton adjacency matrix of g.nx_graph as an int ndarray."""
    adj = nx.to_numpy_array(g.nx_graph)
    return adj.astype(int)
448 | | - |
449 | | - |
def matrix_diff(cg1: CausalGraph, cg2: CausalGraph):
    """Count the adjacency differences between two causal graphs.

    :param cg1: a CausalGraph object
    :param cg2: a CausalGraph object
    :return:
        a pair (count / 2, diff_ls) where count is the number of differing
        adjacency-matrix entries (each undirected edge difference is counted
        twice in a symmetric matrix, hence the division by 2) and diff_ls is
        the list of (i, j) index pairs where the matrices differ
    """
    adj1 = get_adjacancy_matrix(cg1)
    adj2 = get_adjacancy_matrix(cg2)
    # Bug fix: the original inner loop iterated range(len(adj2)) — adj2's
    # ROWS — where adj1's columns were intended, silently assuming both
    # matrices are square and equally sized. np.argwhere compares the full
    # matrices element-wise and yields indices in the same row-major order
    # as the original nested loops.
    diff_ls = [(int(i), int(j)) for i, j in np.argwhere(adj1 != adj2)]
    return len(diff_ls) / 2, diff_ls
0 commit comments