Skip to content

Commit f0ed760

Browse files
committed
Refactored unit tests for FCI
Signed-off-by: ZhiyiHuang <[email protected]>
1 parent 0ae6d6c commit f0ed760

File tree

1 file changed

+22
-45
lines changed

1 file changed

+22
-45
lines changed

tests/TestFCI.py

Lines changed: 22 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
from causallearn.utils.GraphUtils import GraphUtils
1717
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
1818

19-
sys.path.append("")
20-
2119
######################################### Test Notes ###########################################
2220
# All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/") #
2321
# are obtained from the code of causal-learn as of commit #
@@ -54,7 +52,7 @@
5452
# verify files integrity first
5553
for file_path, expected_MD5 in BENCHMARK_TXTFILE_TO_MD5.items():
5654
with open(file_path, 'rb') as fin:
57-
assert hashlib.md5(fin.read()).hexdigest() == expected_MD5,\
55+
assert hashlib.md5(fin.read()).hexdigest() == expected_MD5, \
5856
f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/fb092d1/tests/TestData'
5957

6058

@@ -66,17 +64,16 @@ class TestFCI(unittest.TestCase):
6664
def test_simple_test(self):
6765
data = np.empty(shape=(0, 4))
6866
true_dag = DiGraph()
69-
true_dag.add_edges_from([(0, 1), (0, 2), (1, 3), (2, 3)])
67+
ground_truth_edges = [(0, 1), (0, 2), (1, 3), (2, 3)]
68+
true_dag.add_edges_from(ground_truth_edges)
7069
G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)
7170

7271
ground_truth_nodes = []
7372
for i in range(4):
74-
ground_truth_nodes.append(GraphNode(f'X{i+1}'))
73+
ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
7574
ground_truth_dag = Dag(ground_truth_nodes)
76-
ground_truth_dag.add_directed_edge(ground_truth_nodes[0], ground_truth_nodes[1])
77-
ground_truth_dag.add_directed_edge(ground_truth_nodes[0], ground_truth_nodes[2])
78-
ground_truth_dag.add_directed_edge(ground_truth_nodes[1], ground_truth_nodes[3])
79-
ground_truth_dag.add_directed_edge(ground_truth_nodes[2], ground_truth_nodes[3])
75+
for u, v in ground_truth_edges:
76+
ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])
8077
pag = dag2pag(ground_truth_dag, [])
8178

8279
print(f'fci(data, d_separation, 0.05):')
@@ -86,88 +83,68 @@ def test_simple_test(self):
8683
assert G.is_adjacent_to(nodes[0], nodes[1])
8784

8885
bk = BackgroundKnowledge().add_forbidden_by_node(nodes[0], nodes[1]).add_forbidden_by_node(nodes[1], nodes[0])
89-
G_with_background_knowledge, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag, background_knowledge=bk)
86+
G_with_background_knowledge, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag,
87+
background_knowledge=bk)
9088
assert not G_with_background_knowledge.is_adjacent_to(nodes[0], nodes[1])
9189

9290
def test_simple_test2(self):
9391
data = np.empty(shape=(0, 7))
9492
true_dag = DiGraph()
95-
true_dag.add_edges_from([('L1', 0), ('L1', 1), ('L2', 3), ('L2', 4), (2, 5),
96-
(2, 6), (5, 1), (6, 3), (3, 0), (1, 4)])
93+
ground_truth_edges = [(7, 0), (7, 1), (8, 3), (8, 4), (2, 5), (2, 6), (5, 1), (6, 3), (3, 0), (1, 4)]
94+
true_dag.add_edges_from(ground_truth_edges)
9795
G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)
9896
ground_truth_nodes = []
9997
for i in range(9):
100-
ground_truth_nodes.append(GraphNode(f'X{i+1}'))
98+
ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
10199
ground_truth_dag = Dag(ground_truth_nodes)
102-
ground_truth_dag.add_directed_edge(ground_truth_nodes[7], ground_truth_nodes[0])
103-
ground_truth_dag.add_directed_edge(ground_truth_nodes[7], ground_truth_nodes[1])
104-
ground_truth_dag.add_directed_edge(ground_truth_nodes[8], ground_truth_nodes[3])
105-
ground_truth_dag.add_directed_edge(ground_truth_nodes[8], ground_truth_nodes[4])
106-
ground_truth_dag.add_directed_edge(ground_truth_nodes[2], ground_truth_nodes[5])
107-
ground_truth_dag.add_directed_edge(ground_truth_nodes[2], ground_truth_nodes[6])
108-
ground_truth_dag.add_directed_edge(ground_truth_nodes[5], ground_truth_nodes[1])
109-
ground_truth_dag.add_directed_edge(ground_truth_nodes[6], ground_truth_nodes[3])
110-
ground_truth_dag.add_directed_edge(ground_truth_nodes[3], ground_truth_nodes[0])
111-
ground_truth_dag.add_directed_edge(ground_truth_nodes[1], ground_truth_nodes[4])
100+
for u, v in ground_truth_edges:
101+
ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])
112102

113103
pag = dag2pag(ground_truth_dag, ground_truth_nodes[7: 9])
114104

115105
print(f'fci(data, d_separation, 0.05):')
116106
self.run_simulate_data_test(pag, G)
117107

118-
119108
def test_simple_test3(self):
120109

121110
data = np.empty(shape=(0, 5))
122111
true_dag = DiGraph()
123-
true_dag.add_edges_from([(0, 2), (1, 2), (2, 3), (2, 4)])
112+
ground_truth_edges = [(0, 2), (1, 2), (2, 3), (2, 4)]
113+
true_dag.add_edges_from(ground_truth_edges)
124114
G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)
125115

126116
ground_truth_nodes = []
127117
for i in range(5):
128118
ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
129119
ground_truth_dag = Dag(ground_truth_nodes)
130-
ground_truth_dag.add_directed_edge(ground_truth_nodes[0], ground_truth_nodes[2])
131-
ground_truth_dag.add_directed_edge(ground_truth_nodes[1], ground_truth_nodes[2])
132-
ground_truth_dag.add_directed_edge(ground_truth_nodes[2], ground_truth_nodes[3])
133-
ground_truth_dag.add_directed_edge(ground_truth_nodes[2], ground_truth_nodes[4])
120+
for u, v in ground_truth_edges:
121+
ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])
134122

135123
pag = dag2pag(ground_truth_dag, [])
136124

137125
print(f'fci(data, d_separation, 0.05):')
138126
self.run_simulate_data_test(pag, G)
139127

140-
141128
def test_fritl(self):
142129
data = np.empty(shape=(0, 7))
143130
true_dag = DiGraph()
144-
true_dag.add_edges_from([('L1', 0), ('L1', 5), ('L2', 0), ('L2', 6), ('L3', 3), ('L3', 4), ('L3', 6),
145-
(0, 1), (0, 2), (1, 2), (2, 4), (5, 6)])
131+
ground_truth_edges = [(7, 0), (7, 5), (8, 0), (8, 6), (9, 3), (9, 4), (9, 6),
132+
(0, 1), (0, 2), (1, 2), (2, 4), (5, 6)]
133+
true_dag.add_edges_from(ground_truth_edges)
146134
G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)
147135

148136
ground_truth_nodes = []
149137
for i in range(10):
150138
ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
151139
ground_truth_dag = Dag(ground_truth_nodes)
152-
ground_truth_dag.add_directed_edge(ground_truth_nodes[7], ground_truth_nodes[0])
153-
ground_truth_dag.add_directed_edge(ground_truth_nodes[7], ground_truth_nodes[5])
154-
ground_truth_dag.add_directed_edge(ground_truth_nodes[8], ground_truth_nodes[0])
155-
ground_truth_dag.add_directed_edge(ground_truth_nodes[8], ground_truth_nodes[6])
156-
ground_truth_dag.add_directed_edge(ground_truth_nodes[9], ground_truth_nodes[3])
157-
ground_truth_dag.add_directed_edge(ground_truth_nodes[9], ground_truth_nodes[4])
158-
ground_truth_dag.add_directed_edge(ground_truth_nodes[9], ground_truth_nodes[6])
159-
ground_truth_dag.add_directed_edge(ground_truth_nodes[0], ground_truth_nodes[1])
160-
ground_truth_dag.add_directed_edge(ground_truth_nodes[0], ground_truth_nodes[2])
161-
ground_truth_dag.add_directed_edge(ground_truth_nodes[1], ground_truth_nodes[2])
162-
ground_truth_dag.add_directed_edge(ground_truth_nodes[2], ground_truth_nodes[4])
163-
ground_truth_dag.add_directed_edge(ground_truth_nodes[5], ground_truth_nodes[6])
140+
for u, v in ground_truth_edges:
141+
ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])
164142

165143
pag = dag2pag(ground_truth_dag, ground_truth_nodes[7: 10])
166144

167145
print(f'fci(data, d_separation, 0.05):')
168146
self.run_simulate_data_test(pag, G)
169147

170-
171148
@staticmethod
172149
def run_simulate_data_test(truth, est):
173150
graph_utils = GraphUtils()

0 commit comments

Comments
 (0)