1616from causallearn .utils .GraphUtils import GraphUtils
1717from causallearn .utils .PCUtils .BackgroundKnowledge import BackgroundKnowledge
1818
19- sys .path .append ("" )
20-
2119######################################### Test Notes ###########################################
2220# All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/") #
2321# are obtained from the code of causal-learn as of commit #
5452# verify files integrity first
5553for file_path , expected_MD5 in BENCHMARK_TXTFILE_TO_MD5 .items ():
5654 with open (file_path , 'rb' ) as fin :
57- assert hashlib .md5 (fin .read ()).hexdigest () == expected_MD5 ,\
55+ assert hashlib .md5 (fin .read ()).hexdigest () == expected_MD5 , \
5856 f'{ file_path } is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/fb092d1/tests/TestData'
5957
6058
@@ -66,17 +64,16 @@ class TestFCI(unittest.TestCase):
6664 def test_simple_test (self ):
6765 data = np .empty (shape = (0 , 4 ))
6866 true_dag = DiGraph ()
69- true_dag .add_edges_from ([(0 , 1 ), (0 , 2 ), (1 , 3 ), (2 , 3 )])
67+ ground_truth_edges = [(0 , 1 ), (0 , 2 ), (1 , 3 ), (2 , 3 )]
68+ true_dag .add_edges_from (ground_truth_edges )
7069 G , edges = fci (data , d_separation , 0.05 , verbose = False , true_dag = true_dag )
7170
7271 ground_truth_nodes = []
7372 for i in range (4 ):
74- ground_truth_nodes .append (GraphNode (f'X{ i + 1 } ' ))
73+ ground_truth_nodes .append (GraphNode (f'X{ i + 1 } ' ))
7574 ground_truth_dag = Dag (ground_truth_nodes )
76- ground_truth_dag .add_directed_edge (ground_truth_nodes [0 ], ground_truth_nodes [1 ])
77- ground_truth_dag .add_directed_edge (ground_truth_nodes [0 ], ground_truth_nodes [2 ])
78- ground_truth_dag .add_directed_edge (ground_truth_nodes [1 ], ground_truth_nodes [3 ])
79- ground_truth_dag .add_directed_edge (ground_truth_nodes [2 ], ground_truth_nodes [3 ])
75+ for u , v in ground_truth_edges :
76+ ground_truth_dag .add_directed_edge (ground_truth_nodes [u ], ground_truth_nodes [v ])
8077 pag = dag2pag (ground_truth_dag , [])
8178
8279 print (f'fci(data, d_separation, 0.05):' )
@@ -86,88 +83,68 @@ def test_simple_test(self):
8683 assert G .is_adjacent_to (nodes [0 ], nodes [1 ])
8784
8885 bk = BackgroundKnowledge ().add_forbidden_by_node (nodes [0 ], nodes [1 ]).add_forbidden_by_node (nodes [1 ], nodes [0 ])
89- G_with_background_knowledge , edges = fci (data , d_separation , 0.05 , verbose = False , true_dag = true_dag , background_knowledge = bk )
86+ G_with_background_knowledge , edges = fci (data , d_separation , 0.05 , verbose = False , true_dag = true_dag ,
87+ background_knowledge = bk )
9088 assert not G_with_background_knowledge .is_adjacent_to (nodes [0 ], nodes [1 ])
9189
9290 def test_simple_test2 (self ):
9391 data = np .empty (shape = (0 , 7 ))
9492 true_dag = DiGraph ()
95- true_dag . add_edges_from ([( 'L1' , 0 ), ('L1' , 1 ), ('L2' , 3 ), ('L2' , 4 ), (2 , 5 ),
96- ( 2 , 6 ), ( 5 , 1 ), ( 6 , 3 ), ( 3 , 0 ), ( 1 , 4 )] )
93+ ground_truth_edges = [( 7 , 0 ), (7 , 1 ), (8 , 3 ), (8 , 4 ), (2 , 5 ), ( 2 , 6 ), ( 5 , 1 ), ( 6 , 3 ), ( 3 , 0 ), ( 1 , 4 )]
94+ true_dag . add_edges_from ( ground_truth_edges )
9795 G , edges = fci (data , d_separation , 0.05 , verbose = False , true_dag = true_dag )
9896 ground_truth_nodes = []
9997 for i in range (9 ):
100- ground_truth_nodes .append (GraphNode (f'X{ i + 1 } ' ))
98+ ground_truth_nodes .append (GraphNode (f'X{ i + 1 } ' ))
10199 ground_truth_dag = Dag (ground_truth_nodes )
102- ground_truth_dag .add_directed_edge (ground_truth_nodes [7 ], ground_truth_nodes [0 ])
103- ground_truth_dag .add_directed_edge (ground_truth_nodes [7 ], ground_truth_nodes [1 ])
104- ground_truth_dag .add_directed_edge (ground_truth_nodes [8 ], ground_truth_nodes [3 ])
105- ground_truth_dag .add_directed_edge (ground_truth_nodes [8 ], ground_truth_nodes [4 ])
106- ground_truth_dag .add_directed_edge (ground_truth_nodes [2 ], ground_truth_nodes [5 ])
107- ground_truth_dag .add_directed_edge (ground_truth_nodes [2 ], ground_truth_nodes [6 ])
108- ground_truth_dag .add_directed_edge (ground_truth_nodes [5 ], ground_truth_nodes [1 ])
109- ground_truth_dag .add_directed_edge (ground_truth_nodes [6 ], ground_truth_nodes [3 ])
110- ground_truth_dag .add_directed_edge (ground_truth_nodes [3 ], ground_truth_nodes [0 ])
111- ground_truth_dag .add_directed_edge (ground_truth_nodes [1 ], ground_truth_nodes [4 ])
100+ for u , v in ground_truth_edges :
101+ ground_truth_dag .add_directed_edge (ground_truth_nodes [u ], ground_truth_nodes [v ])
112102
113103 pag = dag2pag (ground_truth_dag , ground_truth_nodes [7 : 9 ])
114104
115105 print (f'fci(data, d_separation, 0.05):' )
116106 self .run_simulate_data_test (pag , G )
117107
118-
119108 def test_simple_test3 (self ):
120109
121110 data = np .empty (shape = (0 , 5 ))
122111 true_dag = DiGraph ()
123- true_dag .add_edges_from ([(0 , 2 ), (1 , 2 ), (2 , 3 ), (2 , 4 )])
112+ ground_truth_edges = [(0 , 2 ), (1 , 2 ), (2 , 3 ), (2 , 4 )]
113+ true_dag .add_edges_from (ground_truth_edges )
124114 G , edges = fci (data , d_separation , 0.05 , verbose = False , true_dag = true_dag )
125115
126116 ground_truth_nodes = []
127117 for i in range (5 ):
128118 ground_truth_nodes .append (GraphNode (f'X{ i + 1 } ' ))
129119 ground_truth_dag = Dag (ground_truth_nodes )
130- ground_truth_dag .add_directed_edge (ground_truth_nodes [0 ], ground_truth_nodes [2 ])
131- ground_truth_dag .add_directed_edge (ground_truth_nodes [1 ], ground_truth_nodes [2 ])
132- ground_truth_dag .add_directed_edge (ground_truth_nodes [2 ], ground_truth_nodes [3 ])
133- ground_truth_dag .add_directed_edge (ground_truth_nodes [2 ], ground_truth_nodes [4 ])
120+ for u , v in ground_truth_edges :
121+ ground_truth_dag .add_directed_edge (ground_truth_nodes [u ], ground_truth_nodes [v ])
134122
135123 pag = dag2pag (ground_truth_dag , [])
136124
137125 print (f'fci(data, d_separation, 0.05):' )
138126 self .run_simulate_data_test (pag , G )
139127
140-
141128 def test_fritl (self ):
142129 data = np .empty (shape = (0 , 7 ))
143130 true_dag = DiGraph ()
144- true_dag .add_edges_from ([('L1' , 0 ), ('L1' , 5 ), ('L2' , 0 ), ('L2' , 6 ), ('L3' , 3 ), ('L3' , 4 ), ('L3' , 6 ),
145- (0 , 1 ), (0 , 2 ), (1 , 2 ), (2 , 4 ), (5 , 6 )])
131+ ground_truth_edges = [(7 , 0 ), (7 , 5 ), (8 , 0 ), (8 , 6 ), (9 , 3 ), (9 , 4 ), (9 , 6 ),
132+ (0 , 1 ), (0 , 2 ), (1 , 2 ), (2 , 4 ), (5 , 6 )]
133+ true_dag .add_edges_from (ground_truth_edges )
146134 G , edges = fci (data , d_separation , 0.05 , verbose = False , true_dag = true_dag )
147135
148136 ground_truth_nodes = []
149137 for i in range (10 ):
150138 ground_truth_nodes .append (GraphNode (f'X{ i + 1 } ' ))
151139 ground_truth_dag = Dag (ground_truth_nodes )
152- ground_truth_dag .add_directed_edge (ground_truth_nodes [7 ], ground_truth_nodes [0 ])
153- ground_truth_dag .add_directed_edge (ground_truth_nodes [7 ], ground_truth_nodes [5 ])
154- ground_truth_dag .add_directed_edge (ground_truth_nodes [8 ], ground_truth_nodes [0 ])
155- ground_truth_dag .add_directed_edge (ground_truth_nodes [8 ], ground_truth_nodes [6 ])
156- ground_truth_dag .add_directed_edge (ground_truth_nodes [9 ], ground_truth_nodes [3 ])
157- ground_truth_dag .add_directed_edge (ground_truth_nodes [9 ], ground_truth_nodes [4 ])
158- ground_truth_dag .add_directed_edge (ground_truth_nodes [9 ], ground_truth_nodes [6 ])
159- ground_truth_dag .add_directed_edge (ground_truth_nodes [0 ], ground_truth_nodes [1 ])
160- ground_truth_dag .add_directed_edge (ground_truth_nodes [0 ], ground_truth_nodes [2 ])
161- ground_truth_dag .add_directed_edge (ground_truth_nodes [1 ], ground_truth_nodes [2 ])
162- ground_truth_dag .add_directed_edge (ground_truth_nodes [2 ], ground_truth_nodes [4 ])
163- ground_truth_dag .add_directed_edge (ground_truth_nodes [5 ], ground_truth_nodes [6 ])
140+ for u , v in ground_truth_edges :
141+ ground_truth_dag .add_directed_edge (ground_truth_nodes [u ], ground_truth_nodes [v ])
164142
165143 pag = dag2pag (ground_truth_dag , ground_truth_nodes [7 : 10 ])
166144
167145 print (f'fci(data, d_separation, 0.05):' )
168146 self .run_simulate_data_test (pag , G )
169147
170-
171148 @staticmethod
172149 def run_simulate_data_test (truth , est ):
173150 graph_utils = GraphUtils ()
0 commit comments