Skip to content

Commit fe68f5b

Browse files
authored
Merge pull request #101 from MarcelRobeer/patch-1
Fixed bug in Fast Adjacency Search (FAS)
2 parents d101f9b + 93d0373 commit fe68f5b

File tree

8 files changed

+160
-75
lines changed

8 files changed

+160
-75
lines changed

causallearn/search/ConstraintBased/FCI.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from causallearn.utils.ChoiceGenerator import ChoiceGenerator
1414
from causallearn.utils.DepthChoiceGenerator import DepthChoiceGenerator
1515
from causallearn.utils.cit import *
16-
from causallearn.utils.Fas import fas
16+
from causallearn.utils.FAS import fas
1717
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
1818

1919

@@ -754,8 +754,8 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
754754
nodes.append(node)
755755

756756
# FAS (“Fast Adjacency Search”) is the adjacency search of the PC algorithm, used as a first step for the FCI algorithm.
757-
graph, sep_sets = fas(dataset, nodes, independence_test_method=independence_test_method, alpha=alpha,
758-
knowledge=background_knowledge, depth=depth, verbose=verbose)
757+
graph, sep_sets, test_results = fas(dataset, nodes, independence_test_method=independence_test_method, alpha=alpha,
758+
knowledge=background_knowledge, depth=depth, verbose=verbose)
759759

760760
reorientAllWith(graph, Endpoint.CIRCLE)
761761

Lines changed: 58 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,23 @@
11
from __future__ import annotations
22

3-
from copy import deepcopy
43
from itertools import combinations
54
from typing import List, Dict, Tuple, Set
65

76
from numpy import ndarray
87
from tqdm.auto import tqdm
98

10-
from causallearn.graph.Edges import Edges
119
from causallearn.graph.GeneralGraph import GeneralGraph
1210
from causallearn.graph.GraphClass import CausalGraph
1311
from causallearn.graph.Node import Node
14-
from causallearn.utils.ChoiceGenerator import ChoiceGenerator
1512
from causallearn.utils.PCUtils.Helper import append_value
1613
from causallearn.utils.cit import *
1714
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
1815

1916

20-
def fas(data: ndarray, nodes: List[Node], independence_test_method: CIT | None=None, alpha: float = 0.05,
17+
def fas(data: ndarray, nodes: List[Node], independence_test_method: CIT_Base, alpha: float = 0.05,
2118
knowledge: BackgroundKnowledge | None = None, depth: int = -1,
2219
verbose: bool = False, stable: bool = True, show_progress: bool = True) -> Tuple[
23-
GeneralGraph, Dict[Tuple[int, int], Set[int]]]:
20+
GeneralGraph, Dict[Tuple[int, int], Set[int]], Dict[Tuple[int, int, Set[int]], float]]:
2421
"""
2522
Implements the "fast adjacency search" used in several causal algorithm in this file. In the fast adjacency
2623
search, at a given stage of the search, an edge X*-*Y is removed from the graph if X _||_ Y | S, where S is a subset
@@ -50,78 +47,87 @@ def fas(data: ndarray, nodes: List[Node], independence_test_method: CIT | None=N
5047
Returns
5148
-------
5249
graph: Causal graph skeleton, where graph.graph[i,j] = graph.graph[j,i] = -1 indicates i --- j.
53-
sep_sets: separated sets of graph
50+
sep_sets: Separated sets of graph
51+
test_results: Results of conditional independence tests
5452
"""
55-
56-
assert type(data) == np.ndarray
57-
assert 0 < alpha < 1
53+
## ------- check parameters ------------
54+
if type(data) != np.ndarray:
55+
raise TypeError("'data' must be 'np.ndarray' type!")
56+
if not all(isinstance(node, Node) for node in nodes):
57+
raise TypeError("'nodes' must be 'List[Node]' type!")
58+
if not isinstance(independence_test_method, CIT_Base):
59+
raise TypeError("'independence_test_method' must be 'CIT_Base' type!")
60+
if type(alpha) != float or alpha <= 0 or alpha >= 1:
61+
raise TypeError("'alpha' must be 'float' type and between 0 and 1!")
62+
if knowledge is not None and type(knowledge) != BackgroundKnowledge:
63+
raise TypeError("'knowledge' must be 'BackgroundKnowledge' type!")
64+
if type(depth) != int or depth < -1:
65+
raise TypeError("'depth' must be 'int' type >= -1!")
66+
## ------- end check parameters ------------
67+
68+
if depth == -1:
69+
depth = float('inf')
5870

5971
no_of_var = data.shape[1]
6072
node_names = [node.get_name() for node in nodes]
6173
cg = CausalGraph(no_of_var, node_names)
6274
cg.set_ind_test(independence_test_method)
6375
sep_sets: Dict[Tuple[int, int], Set[int]] = {}
64-
65-
depth = -1
66-
pbar = tqdm(total=no_of_var) if show_progress else None
67-
while cg.max_degree() - 1 > depth:
68-
depth += 1
69-
edge_removal = []
70-
if show_progress:
71-
pbar.reset()
72-
for x in range(no_of_var):
76+
test_results: Dict[Tuple[int, int, Set[int]], float] = {}
77+
78+
def remove_if_exists(x: int, y: int) -> None:
79+
edge = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
80+
if edge is not None:
81+
cg.G.remove_edge(edge)
82+
83+
var_range = tqdm(range(no_of_var), leave=True) if show_progress \
84+
else range(no_of_var)
85+
current_depth: int = -1
86+
while cg.max_degree() - 1 > current_depth and current_depth < depth:
87+
current_depth += 1
88+
edge_removal = set()
89+
for x in var_range:
7390
if show_progress:
74-
pbar.update()
75-
if show_progress:
76-
pbar.set_description(f'Depth={depth}, working on node {x}')
91+
var_range.set_description(f'Depth={current_depth}, working on node {x}')
92+
var_range.update()
7793
Neigh_x = cg.neighbors(x)
78-
if len(Neigh_x) < depth - 1:
94+
if len(Neigh_x) < current_depth - 1:
7995
continue
8096
for y in Neigh_x:
81-
knowledge_ban_edge = False
8297
sepsets = set()
83-
if knowledge is not None and (
84-
knowledge.is_forbidden(cg.G.nodes[x], cg.G.nodes[y])
85-
and knowledge.is_forbidden(cg.G.nodes[y], cg.G.nodes[x])):
86-
knowledge_ban_edge = True
87-
if knowledge_ban_edge:
98+
if (knowledge is not None and
99+
knowledge.is_forbidden(cg.G.nodes[x], cg.G.nodes[y])
100+
and knowledge.is_forbidden(cg.G.nodes[y], cg.G.nodes[x])):
88101
if not stable:
89-
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
90-
if edge1 is not None:
91-
cg.G.remove_edge(edge1)
92-
edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
93-
if edge2 is not None:
94-
cg.G.remove_edge(edge2)
102+
remove_if_exists(x, y)
103+
remove_if_exists(y, x)
95104
append_value(cg.sepset, x, y, ())
96105
append_value(cg.sepset, y, x, ())
97106
sep_sets[(x, y)] = set()
98107
sep_sets[(y, x)] = set()
99108
break
100109
else:
101-
edge_removal.append((x, y)) # after all conditioning sets at
102-
edge_removal.append((y, x)) # depth l have been considered
110+
edge_removal.add((x, y)) # after all conditioning sets at
111+
edge_removal.add((y, x)) # depth l have been considered
103112

104113
Neigh_x_noy = np.delete(Neigh_x, np.where(Neigh_x == y))
105-
for S in combinations(Neigh_x_noy, depth):
114+
for S in combinations(Neigh_x_noy, current_depth):
106115
p = cg.ci_test(x, y, S)
116+
test_results[(x, y, S)] = p
107117
if p > alpha:
108118
if verbose:
109119
print('%d ind %d | %s with p-value %f\n' % (x, y, S, p))
110120
if not stable:
111-
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
112-
if edge1 is not None:
113-
cg.G.remove_edge(edge1)
114-
edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
115-
if edge2 is not None:
116-
cg.G.remove_edge(edge2)
121+
remove_if_exists(x, y)
122+
remove_if_exists(y, x)
117123
append_value(cg.sepset, x, y, S)
118124
append_value(cg.sepset, y, x, S)
119125
sep_sets[(x, y)] = set(S)
120126
sep_sets[(y, x)] = set(S)
121127
break
122128
else:
123-
edge_removal.append((x, y)) # after all conditioning sets at
124-
edge_removal.append((y, x)) # depth l have been considered
129+
edge_removal.add((x, y)) # after all conditioning sets at
130+
edge_removal.add((y, x)) # depth l have been considered
125131
for s in S:
126132
sepsets.add(s)
127133
else:
@@ -130,32 +136,12 @@ def fas(data: ndarray, nodes: List[Node], independence_test_method: CIT | None=N
130136
append_value(cg.sepset, x, y, tuple(sepsets))
131137
append_value(cg.sepset, y, x, tuple(sepsets))
132138

133-
if show_progress:
134-
pbar.refresh()
135-
136-
for (x, y) in list(set(edge_removal)):
137-
edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
138-
if edge1 is not None:
139-
cg.G.remove_edge(edge1)
139+
for (x, y) in edge_removal:
140+
remove_if_exists(x, y)
140141
if cg.sepset[x, y] is not None:
141-
origin_list = []
142-
for l_out in cg.sepset[x, y]:
143-
for l_in in l_out:
144-
origin_list.append(l_in)
145-
sep_sets[(x, y)] = set(origin_list)
146-
sep_sets[(y, x)] = set(origin_list)
147-
148-
149-
# for x in range(no_of_var):
150-
# for y in range(x, no_of_var):
151-
# if cg.sepset[x, y] is not None:
152-
# origin_list = []
153-
# for l_out in cg.sepset[x, y]:
154-
# for l_in in l_out:
155-
# origin_list.append(l_in)
156-
# sep_sets[(x, y)] = set(origin_list)
157-
158-
if show_progress:
159-
pbar.close()
142+
origin_set = set(l_in for l_out in cg.sepset[x, y]
143+
for l_in in l_out)
144+
sep_sets[(x, y)] = origin_set
145+
sep_sets[(y, x)] = origin_set
160146

161-
return cg.G, sep_sets
147+
return cg.G, sep_sets, test_results

tests/TestFAS.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import hashlib
2+
import os
3+
import random
4+
import unittest
5+
6+
import numpy as np
7+
8+
from causallearn.graph.GraphNode import GraphNode
9+
from causallearn.utils.cit import CIT, chisq, fisherz, kci, d_separation
10+
from causallearn.utils.FAS import fas
11+
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
12+
13+
BENCHMARK_TXTFILE_TO_MD5 = {
14+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_asia_fci_chisq_0.05.txt": "65f54932a9d8224459e56c40129e6d8b",
15+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_cancer_fci_chisq_0.05.txt": "0312381641cb3b4818e0c8539f74e802",
16+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_earthquake_fci_chisq_0.05.txt": "a1160b92ce15a700858552f08e43b7de",
17+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_sachs_fci_chisq_0.05.txt": "dced4a202fc32eceb75f53159fc81f3b",
18+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_survey_fci_chisq_0.05.txt": "b1a28eee1e0c6ea8a64ac1624585c3f4",
19+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_alarm_fci_chisq_0.05.txt": "c3bbc2b8aba456a4258dd071a42085bc",
20+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_barley_fci_chisq_0.05.txt": "4a5000e7a582083859ee6aef15073676",
21+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_child_fci_chisq_0.05.txt": "6b7858589e12f04b0f489ba4589a1254",
22+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_insurance_fci_chisq_0.05.txt": "9975942b936aa2b1fc90c09318ca2d08",
23+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_water_fci_chisq_0.05.txt": "48eee804d59526187b7ecd0519556ee5",
24+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_hailfinder_fci_chisq_0.05.txt": "6b9a6b95b6474f8530e85c022f4e749c",
25+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_hepar2_fci_chisq_0.05.txt": "4aae21ff3d9aa2435515ed2ee402294c",
26+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_win95pts_fci_chisq_0.05.txt": "648fdf271e1440c06ca2b31b55ef1f3f",
27+
"tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_andes_fci_chisq_0.05.txt": "04092ae93e54c727579f08bf5dc34c77",
28+
"tests/TestData/benchmark_returned_results/linear_10_fci_fisherz_0.05.txt": "289c86f9c665bf82bbcc4c9e1dcec3e7"
29+
}
30+
31+
# verify files integrity first
32+
for file_path, expected_MD5 in BENCHMARK_TXTFILE_TO_MD5.items():
33+
with open(file_path, 'rb') as fin:
34+
assert hashlib.md5(fin.read()).hexdigest() == expected_MD5, \
35+
f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/5918419/tests/TestData'
36+
37+
38+
class TestFAS(unittest.TestCase):
39+
def test_inputs(self):
40+
data = np.loadtxt('tests/data_linear_10.txt', skiprows=1)
41+
alpha = 0.05
42+
cit = CIT(data, fisherz, alpha=alpha)
43+
nodes = [GraphNode(f"X{i + 1}") for i in range(data.shape[1])]
44+
bgk = BackgroundKnowledge()
45+
self.assertRaises(TypeError, fas, data=None, nodes=nodes, independence_test_method=cit, alpha=alpha, knowledge=bgk, verbose=False)
46+
self.assertRaises(TypeError, fas, data=data, nodes=None, independence_test_method=cit, alpha=alpha, knowledge=bgk, verbose=False)
47+
self.assertRaises(TypeError, fas, data=data, nodes=nodes, independence_test_method=None, alpha=alpha, knowledge=bgk, verbose=False)
48+
self.assertRaises(TypeError, fas, data=data, nodes=nodes, independence_test_method=cit, alpha=1, knowledge=bgk, verbose=False)
49+
self.assertRaises(TypeError, fas, data=data, nodes=nodes, independence_test_method=cit, alpha=0, knowledge=bgk, verbose=False)
50+
self.assertRaises(TypeError, fas, data=data, nodes=nodes, independence_test_method=cit, alpha=alpha, knowledge=data, verbose=False)
51+
52+
@staticmethod
53+
def run_test_with_random_background(data, cit, alpha):
54+
random.seed(42)
55+
56+
nodes = [GraphNode(f"X{i + 1}") for i in range(data.shape[1])]
57+
bgk = BackgroundKnowledge()
58+
for _ in range(5):
59+
node1, node2 = random.sample(nodes, 2)
60+
bgk.add_forbidden_by_node(node1, node2)
61+
bgk.add_forbidden_by_node(node2, node1)
62+
G, edges, test_results = fas(data, nodes, cit, alpha, knowledge=bgk, verbose=False)
63+
assert G.num_vars == data.shape[1], 'Graph should contain the same number of nodes as variables.'
64+
assert all(G.get_edge(x, y) is None for x, y in bgk.forbidden_rules_specs), 'Graph contains forbidden edges.'
65+
66+
@staticmethod
67+
def run_test_at_depths(data, cit, alpha):
68+
random.seed(42)
69+
70+
nodes = [GraphNode(f"X{i + 1}") for i in range(data.shape[1])]
71+
for _ in range(3):
72+
depth = random.randint(1, min(data.shape[1], 5))
73+
G, edges, test_results = fas(data, nodes, cit, alpha, depth=depth, verbose=False)
74+
assert max(len(S) for _, _, S in test_results.keys()) <= depth, 'Tests performed with depth greater than maximum depth.'
75+
76+
def test_bnlearn_discrete_datasets(self):
77+
benchmark_names = [
78+
"asia", "cancer", "earthquake", "sachs", "survey",
79+
"alarm", "barley", "child", "insurance", "water",
80+
"hailfinder", "hepar2", "win95pts",
81+
"andes"
82+
]
83+
84+
bnlearn_path = 'tests/TestData/bnlearn_discrete_10000/data'
85+
alpha = 0.05
86+
for bname in benchmark_names:
87+
print(f'Testing discrete dataset "{bname}...')
88+
data = np.loadtxt(os.path.join(bnlearn_path, f'{bname}.txt'), skiprows=1)
89+
cit = CIT(data, chisq, alpha=alpha)
90+
TestFAS.run_test_with_random_background(data, cit, alpha)
91+
TestFAS.run_test_at_depths(data, cit, alpha)
92+
93+
def test_continuous_dataset(self):
94+
print('Testing continuous dataset...')
95+
data = np.loadtxt('tests/data_linear_10.txt', skiprows=1)
96+
alpha = 0.05
97+
cit = CIT(data, fisherz, alpha=alpha)
98+
TestFAS.run_test_with_random_background(data, cit, alpha)
99+
TestFAS.run_test_at_depths(data, cit, alpha)
-6.13 KB
Binary file not shown.
-4.79 KB
Binary file not shown.
-22.4 KB
Binary file not shown.
-133 Bytes
Binary file not shown.
-3.85 KB
Binary file not shown.

0 commit comments

Comments
 (0)