11from __future__ import annotations
22
3- from copy import deepcopy
43from itertools import combinations
54from typing import List , Dict , Tuple , Set
65
76from numpy import ndarray
87from tqdm .auto import tqdm
98
10- from causallearn .graph .Edges import Edges
119from causallearn .graph .GeneralGraph import GeneralGraph
1210from causallearn .graph .GraphClass import CausalGraph
1311from causallearn .graph .Node import Node
14- from causallearn .utils .ChoiceGenerator import ChoiceGenerator
1512from causallearn .utils .PCUtils .Helper import append_value
1613from causallearn .utils .cit import *
1714from causallearn .utils .PCUtils .BackgroundKnowledge import BackgroundKnowledge
1815
1916
20- def fas (data : ndarray , nodes : List [Node ], independence_test_method : CIT | None = None , alpha : float = 0.05 ,
17+ def fas (data : ndarray , nodes : List [Node ], independence_test_method : CIT_Base , alpha : float = 0.05 ,
2118 knowledge : BackgroundKnowledge | None = None , depth : int = - 1 ,
2219 verbose : bool = False , stable : bool = True , show_progress : bool = True ) -> Tuple [
23- GeneralGraph , Dict [Tuple [int , int ], Set [int ]]]:
20+ GeneralGraph , Dict [Tuple [int , int ], Set [int ]], Dict [ Tuple [ int , int , Set [ int ]], float ] ]:
2421 """
2522 Implements the "fast adjacency search" used in several causal algorithm in this file. In the fast adjacency
2623 search, at a given stage of the search, an edge X*-*Y is removed from the graph if X _||_ Y | S, where S is a subset
@@ -50,78 +47,87 @@ def fas(data: ndarray, nodes: List[Node], independence_test_method: CIT | None=N
5047 Returns
5148 -------
5249 graph: Causal graph skeleton, where graph.graph[i,j] = graph.graph[j,i] = -1 indicates i --- j.
53- sep_sets: separated sets of graph
50+ sep_sets: Separated sets of graph
51+ test_results: Results of conditional independence tests
5452 """
55-
56- assert type (data ) == np .ndarray
57- assert 0 < alpha < 1
53+ ## ------- check parameters ------------
54+ if type (data ) != np .ndarray :
55+ raise TypeError ("'data' must be 'np.ndarray' type!" )
56+ if not all (isinstance (node , Node ) for node in nodes ):
57+ raise TypeError ("'nodes' must be 'List[Node]' type!" )
58+ if not isinstance (independence_test_method , CIT_Base ):
59+ raise TypeError ("'independence_test_method' must be 'CIT_Base' type!" )
60+ if type (alpha ) != float or alpha <= 0 or alpha >= 1 :
61+ raise TypeError ("'alpha' must be 'float' type and between 0 and 1!" )
62+ if knowledge is not None and type (knowledge ) != BackgroundKnowledge :
63+ raise TypeError ("'knowledge' must be 'BackgroundKnowledge' type!" )
64+ if type (depth ) != int or depth < - 1 :
65+ raise TypeError ("'depth' must be 'int' type >= -1!" )
66+ ## ------- end check parameters ------------
67+
68+ if depth == - 1 :
69+ depth = float ('inf' )
5870
5971 no_of_var = data .shape [1 ]
6072 node_names = [node .get_name () for node in nodes ]
6173 cg = CausalGraph (no_of_var , node_names )
6274 cg .set_ind_test (independence_test_method )
6375 sep_sets : Dict [Tuple [int , int ], Set [int ]] = {}
64-
65- depth = - 1
66- pbar = tqdm (total = no_of_var ) if show_progress else None
67- while cg .max_degree () - 1 > depth :
68- depth += 1
69- edge_removal = []
70- if show_progress :
71- pbar .reset ()
72- for x in range (no_of_var ):
76+ test_results : Dict [Tuple [int , int , Set [int ]], float ] = {}
77+
78+ def remove_if_exists (x : int , y : int ) -> None :
79+ edge = cg .G .get_edge (cg .G .nodes [x ], cg .G .nodes [y ])
80+ if edge is not None :
81+ cg .G .remove_edge (edge )
82+
83+ var_range = tqdm (range (no_of_var ), leave = True ) if show_progress \
84+ else range (no_of_var )
85+ current_depth : int = - 1
86+ while cg .max_degree () - 1 > current_depth and current_depth < depth :
87+ current_depth += 1
88+ edge_removal = set ()
89+ for x in var_range :
7390 if show_progress :
74- pbar .update ()
75- if show_progress :
76- pbar .set_description (f'Depth={ depth } , working on node { x } ' )
91+ var_range .set_description (f'Depth={ current_depth } , working on node { x } ' )
92+ var_range .update ()
7793 Neigh_x = cg .neighbors (x )
78- if len (Neigh_x ) < depth - 1 :
94+ if len (Neigh_x ) < current_depth - 1 :
7995 continue
8096 for y in Neigh_x :
81- knowledge_ban_edge = False
8297 sepsets = set ()
83- if knowledge is not None and (
84- knowledge .is_forbidden (cg .G .nodes [x ], cg .G .nodes [y ])
85- and knowledge .is_forbidden (cg .G .nodes [y ], cg .G .nodes [x ])):
86- knowledge_ban_edge = True
87- if knowledge_ban_edge :
98+ if (knowledge is not None and
99+ knowledge .is_forbidden (cg .G .nodes [x ], cg .G .nodes [y ])
100+ and knowledge .is_forbidden (cg .G .nodes [y ], cg .G .nodes [x ])):
88101 if not stable :
89- edge1 = cg .G .get_edge (cg .G .nodes [x ], cg .G .nodes [y ])
90- if edge1 is not None :
91- cg .G .remove_edge (edge1 )
92- edge2 = cg .G .get_edge (cg .G .nodes [y ], cg .G .nodes [x ])
93- if edge2 is not None :
94- cg .G .remove_edge (edge2 )
102+ remove_if_exists (x , y )
103+ remove_if_exists (y , x )
95104 append_value (cg .sepset , x , y , ())
96105 append_value (cg .sepset , y , x , ())
97106 sep_sets [(x , y )] = set ()
98107 sep_sets [(y , x )] = set ()
99108 break
100109 else :
101- edge_removal .append ((x , y )) # after all conditioning sets at
102- edge_removal .append ((y , x )) # depth l have been considered
110+ edge_removal .add ((x , y )) # after all conditioning sets at
111+ edge_removal .add ((y , x )) # depth l have been considered
103112
104113 Neigh_x_noy = np .delete (Neigh_x , np .where (Neigh_x == y ))
105- for S in combinations (Neigh_x_noy , depth ):
114+ for S in combinations (Neigh_x_noy , current_depth ):
106115 p = cg .ci_test (x , y , S )
116+ test_results [(x , y , S )] = p
107117 if p > alpha :
108118 if verbose :
109119 print ('%d ind %d | %s with p-value %f\n ' % (x , y , S , p ))
110120 if not stable :
111- edge1 = cg .G .get_edge (cg .G .nodes [x ], cg .G .nodes [y ])
112- if edge1 is not None :
113- cg .G .remove_edge (edge1 )
114- edge2 = cg .G .get_edge (cg .G .nodes [y ], cg .G .nodes [x ])
115- if edge2 is not None :
116- cg .G .remove_edge (edge2 )
121+ remove_if_exists (x , y )
122+ remove_if_exists (y , x )
117123 append_value (cg .sepset , x , y , S )
118124 append_value (cg .sepset , y , x , S )
119125 sep_sets [(x , y )] = set (S )
120126 sep_sets [(y , x )] = set (S )
121127 break
122128 else :
123- edge_removal .append ((x , y )) # after all conditioning sets at
124- edge_removal .append ((y , x )) # depth l have been considered
129+ edge_removal .add ((x , y )) # after all conditioning sets at
130+ edge_removal .add ((y , x )) # depth l have been considered
125131 for s in S :
126132 sepsets .add (s )
127133 else :
@@ -130,32 +136,12 @@ def fas(data: ndarray, nodes: List[Node], independence_test_method: CIT | None=N
130136 append_value (cg .sepset , x , y , tuple (sepsets ))
131137 append_value (cg .sepset , y , x , tuple (sepsets ))
132138
133- if show_progress :
134- pbar .refresh ()
135-
136- for (x , y ) in list (set (edge_removal )):
137- edge1 = cg .G .get_edge (cg .G .nodes [x ], cg .G .nodes [y ])
138- if edge1 is not None :
139- cg .G .remove_edge (edge1 )
139+ for (x , y ) in edge_removal :
140+ remove_if_exists (x , y )
140141 if cg .sepset [x , y ] is not None :
141- origin_list = []
142- for l_out in cg .sepset [x , y ]:
143- for l_in in l_out :
144- origin_list .append (l_in )
145- sep_sets [(x , y )] = set (origin_list )
146- sep_sets [(y , x )] = set (origin_list )
147-
148-
149- # for x in range(no_of_var):
150- # for y in range(x, no_of_var):
151- # if cg.sepset[x, y] is not None:
152- # origin_list = []
153- # for l_out in cg.sepset[x, y]:
154- # for l_in in l_out:
155- # origin_list.append(l_in)
156- # sep_sets[(x, y)] = set(origin_list)
157-
158- if show_progress :
159- pbar .close ()
142+ origin_set = set (l_in for l_out in cg .sepset [x , y ]
143+ for l_in in l_out )
144+ sep_sets [(x , y )] = origin_set
145+ sep_sets [(y , x )] = origin_set
160146
161- return cg .G , sep_sets
147+ return cg .G , sep_sets , test_results
0 commit comments