@@ -50,9 +50,12 @@ def test_valid_causal_dag(self):
50
50
"""Test whether the Causal DAG is valid."""
51
51
causal_dag = CausalDAG (self .dag_dot_path )
52
52
print (causal_dag )
53
- assert list (causal_dag .graph .nodes ) == ["A" , "B" , "C" , "D" ] and list (
54
- causal_dag .graph .edges
55
- ) == [("A" , "B" ), ("B" , "C" ), ("D" , "A" ), ("D" , "C" )]
53
+ assert list (causal_dag .graph .nodes ) == ["A" , "B" , "C" , "D" ] and list (causal_dag .graph .edges ) == [
54
+ ("A" , "B" ),
55
+ ("B" , "C" ),
56
+ ("D" , "A" ),
57
+ ("D" , "C" ),
58
+ ]
56
59
57
60
def test_invalid_causal_dag (self ):
58
61
"""Test whether a cycle-containing directed graph is an invalid causal DAG."""
@@ -96,9 +99,7 @@ class TestDAGDirectEffectIdentification(unittest.TestCase):
96
99
def setUp (self ) -> None :
97
100
temp_dir_path = create_temp_dir_if_non_existent ()
98
101
self .dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
99
- dag_dot = (
100
- """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
101
- )
102
+ dag_dot = """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
102
103
f = open (self .dag_dot_path , "w" )
103
104
f .write (dag_dot )
104
105
f .close ()
@@ -122,9 +123,7 @@ class TestDAGIdentification(unittest.TestCase):
122
123
def setUp (self ) -> None :
123
124
temp_dir_path = create_temp_dir_if_non_existent ()
124
125
self .dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
125
- dag_dot = (
126
- """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
127
- )
126
+ dag_dot = """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
128
127
f = open (self .dag_dot_path , "w" )
129
128
f .write (dag_dot )
130
129
f .close ()
@@ -137,13 +136,10 @@ def test_get_indirect_graph(self):
137
136
self .assertEqual (list (indirect_graph .graph .edges ), original_edges )
138
137
self .assertEqual (indirect_graph .graph .nodes , causal_dag .graph .nodes )
139
138
140
-
141
139
def test_proper_backdoor_graph (self ):
142
140
"""Test whether converting a Causal DAG to a proper back-door graph works correctly."""
143
141
causal_dag = CausalDAG (self .dag_dot_path )
144
- proper_backdoor_graph = causal_dag .get_proper_backdoor_graph (
145
- ["X1" , "X2" ], ["Y" ]
146
- )
142
+ proper_backdoor_graph = causal_dag .get_proper_backdoor_graph (["X1" , "X2" ], ["Y" ])
147
143
self .assertEqual (
148
144
list (proper_backdoor_graph .graph .edges ),
149
145
[
@@ -163,11 +159,7 @@ def test_constructive_backdoor_criterion_should_hold(self):
163
159
causal_dag = CausalDAG (self .dag_dot_path )
164
160
xs , ys , zs = ["X1" , "X2" ], ["Y" ], ["Z" ]
165
161
proper_backdoor_graph = causal_dag .get_proper_backdoor_graph (xs , ys )
166
- self .assertTrue (
167
- causal_dag .constructive_backdoor_criterion (
168
- proper_backdoor_graph , xs , ys , zs
169
- )
170
- )
162
+ self .assertTrue (causal_dag .constructive_backdoor_criterion (proper_backdoor_graph , xs , ys , zs ))
171
163
172
164
def test_constructive_backdoor_criterion_should_not_hold_not_d_separator_in_proper_backdoor_graph (
173
165
self ,
@@ -176,11 +168,7 @@ def test_constructive_backdoor_criterion_should_not_hold_not_d_separator_in_prop
176
168
causal_dag = CausalDAG (self .dag_dot_path )
177
169
xs , ys , zs = ["X1" , "X2" ], ["Y" ], ["V" ]
178
170
proper_backdoor_graph = causal_dag .get_proper_backdoor_graph (xs , ys )
179
- self .assertFalse (
180
- causal_dag .constructive_backdoor_criterion (
181
- proper_backdoor_graph , xs , ys , zs
182
- )
183
- )
171
+ self .assertFalse (causal_dag .constructive_backdoor_criterion (proper_backdoor_graph , xs , ys , zs ))
184
172
185
173
def test_constructive_backdoor_criterion_should_not_hold_descendent_of_proper_causal_path (
186
174
self ,
@@ -190,11 +178,7 @@ def test_constructive_backdoor_criterion_should_not_hold_descendent_of_proper_ca
190
178
causal_dag = CausalDAG (self .dag_dot_path )
191
179
xs , ys , zs = ["X1" , "X2" ], ["Y" ], ["D1" ]
192
180
proper_backdoor_graph = causal_dag .get_proper_backdoor_graph (xs , ys )
193
- self .assertFalse (
194
- causal_dag .constructive_backdoor_criterion (
195
- proper_backdoor_graph , xs , ys , zs
196
- )
197
- )
181
+ self .assertFalse (causal_dag .constructive_backdoor_criterion (proper_backdoor_graph , xs , ys , zs ))
198
182
199
183
def test_is_min_adjustment_for_min_adjustment (self ):
200
184
"""Test whether is_min_adjustment can correctly test whether the minimum adjustment set is minimal."""
@@ -262,9 +246,7 @@ def test_enumerate_minimal_adjustment_sets_multiple(self):
262
246
)
263
247
xs , ys = ["X1" , "X2" ], ["Y" ]
264
248
adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
265
- set_of_adjustment_sets = set (
266
- frozenset (min_separator ) for min_separator in adjustment_sets
267
- )
249
+ set_of_adjustment_sets = set (frozenset (min_separator ) for min_separator in adjustment_sets )
268
250
self .assertEqual (
269
251
{frozenset ({"Z1" }), frozenset ({"Z2" }), frozenset ({"Z3" })},
270
252
set_of_adjustment_sets ,
@@ -291,9 +273,7 @@ def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
291
273
)
292
274
xs , ys = ["X1" , "X2" ], ["Y" ]
293
275
adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
294
- set_of_adjustment_sets = set (
295
- frozenset (min_separator ) for min_separator in adjustment_sets
296
- )
276
+ set_of_adjustment_sets = set (frozenset (min_separator ) for min_separator in adjustment_sets )
297
277
self .assertEqual (
298
278
{frozenset ({"Z1" , "Z4" }), frozenset ({"Z2" , "Z4" }), frozenset ({"Z3" , "Z4" })},
299
279
set_of_adjustment_sets ,
@@ -304,20 +284,20 @@ def test_dag_with_non_character_nodes(self):
304
284
causal_dag = CausalDAG ()
305
285
causal_dag .graph .add_edges_from (
306
286
[
307
- ('va' , 'ba' ),
308
- ('ba' , 'ia' ),
309
- ('ba' , 'da' ),
310
- ('ba' , 'ra' ),
311
- ('la' , 'va' ),
312
- ('la' , 'aa' ),
313
- ('aa' , 'ia' ),
314
- ('aa' , 'da' ),
315
- ('aa' , 'ra' ),
287
+ ("va" , "ba" ),
288
+ ("ba" , "ia" ),
289
+ ("ba" , "da" ),
290
+ ("ba" , "ra" ),
291
+ ("la" , "va" ),
292
+ ("la" , "aa" ),
293
+ ("aa" , "ia" ),
294
+ ("aa" , "da" ),
295
+ ("aa" , "ra" ),
316
296
]
317
297
)
318
- xs , ys = ['ba' ], ['da' ]
298
+ xs , ys = ["ba" ], ["da" ]
319
299
adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
320
- self .assertEqual (adjustment_sets , [{'aa' }, {'la' }, {'va' }])
300
+ self .assertEqual (adjustment_sets , [{"aa" }, {"la" }, {"va" }])
321
301
322
302
def tearDown (self ) -> None :
323
303
remove_temp_dir_if_existent ()
@@ -385,9 +365,7 @@ class TestUndirectedGraphAlgorithms(unittest.TestCase):
385
365
386
366
def setUp (self ) -> None :
387
367
self .graph = nx .Graph ()
388
- self .graph .add_edges_from (
389
- [("a" , 2 ), ("a" , 3 ), (2 , 4 ), (3 , 5 ), (3 , 4 ), (4 , "b" ), (5 , "b" )]
390
- )
368
+ self .graph .add_edges_from ([("a" , 2 ), ("a" , 3 ), (2 , 4 ), (3 , 5 ), (3 , 4 ), (4 , "b" ), (5 , "b" )])
391
369
self .treatment_node = "a"
392
370
self .outcome_node = "b"
393
371
self .treatment_node_set = {"a" }
@@ -396,9 +374,7 @@ def setUp(self) -> None:
396
374
397
375
def test_close_separator (self ):
398
376
"""Test whether close_separator correctly identifies the close separator of {2,3} in the undirected graph."""
399
- result = close_separator (
400
- self .graph , self .treatment_node , self .outcome_node , self .treatment_node_set
401
- )
377
+ result = close_separator (self .graph , self .treatment_node , self .outcome_node , self .treatment_node_set )
402
378
self .assertEqual ({2 , 3 }, result )
403
379
404
380
def test_list_all_min_sep (self ):
@@ -414,12 +390,8 @@ def test_list_all_min_sep(self):
414
390
)
415
391
416
392
# Convert list of sets to set of frozen sets for comparison
417
- min_separators = set (
418
- frozenset (min_separator ) for min_separator in min_separators
419
- )
420
- self .assertEqual (
421
- {frozenset ({2 , 3 }), frozenset ({3 , 4 }), frozenset ({4 , 5 })}, min_separators
422
- )
393
+ min_separators = set (frozenset (min_separator ) for min_separator in min_separators )
394
+ self .assertEqual ({frozenset ({2 , 3 }), frozenset ({3 , 4 }), frozenset ({4 , 5 })}, min_separators )
423
395
424
396
def tearDown (self ) -> None :
425
397
remove_temp_dir_if_existent ()
0 commit comments