2
2
import os
3
3
import shutil , tempfile
4
4
import networkx as nx
5
- from causal_testing .specification .causal_dag import CausalDAG , close_separator , list_all_min_sep , CausalDAG
6
- from causal_testing .specification .optimised_causal_dag import CausalDAG as OptimisedCausalDAG
5
+ from causal_testing .specification .causal_dag import CausalDAG , close_separator , list_all_min_sep
7
6
from causal_testing .specification .scenario import Scenario
8
7
from causal_testing .specification .variable import Input , Output
9
8
from causal_testing .testing .base_test_case import BaseTestCase
@@ -26,7 +25,7 @@ def test_enumerate_minimal_adjustment_sets(self):
26
25
causal_dag = CausalDAG (self .dag_dot_path )
27
26
xs , ys = ["X" ], ["Y" ]
28
27
adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
29
- self .assertEqual ([{"Z" }], adjustment_sets )
28
+ self .assertEqual ([{"Z" }], list ( adjustment_sets ) )
30
29
31
30
def tearDown (self ) -> None :
32
31
shutil .rmtree (self .temp_dir_path )
@@ -47,19 +46,19 @@ def test_valid_iv(self):
47
46
48
47
def test_unrelated_instrument (self ):
49
48
causal_dag = CausalDAG (self .dag_dot_path )
50
- causal_dag .graph . remove_edge ("I" , "X" )
49
+ causal_dag .remove_edge ("I" , "X" )
51
50
with self .assertRaises (ValueError ):
52
51
causal_dag .check_iv_assumptions ("X" , "Y" , "I" )
53
52
54
53
def test_direct_cause (self ):
55
54
causal_dag = CausalDAG (self .dag_dot_path )
56
- causal_dag .graph . add_edge ("I" , "Y" )
55
+ causal_dag .add_edge ("I" , "Y" )
57
56
with self .assertRaises (ValueError ):
58
57
causal_dag .check_iv_assumptions ("X" , "Y" , "I" )
59
58
60
59
def test_common_cause (self ):
61
60
causal_dag = CausalDAG (self .dag_dot_path )
62
- causal_dag .graph . add_edge ("U" , "I" )
61
+ causal_dag .add_edge ("U" , "I" )
63
62
with self .assertRaises (ValueError ):
64
63
causal_dag .check_iv_assumptions ("X" , "Y" , "I" )
65
64
@@ -280,12 +279,12 @@ def test_enumerate_minimal_adjustment_sets(self):
280
279
causal_dag = CausalDAG (self .dag_dot_path )
281
280
xs , ys = ["X1" , "X2" ], ["Y" ]
282
281
adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
283
- self .assertEqual ([{"Z" }], adjustment_sets )
282
+ self .assertEqual ([{"Z" }], list ( adjustment_sets ) )
284
283
285
284
def test_enumerate_minimal_adjustment_sets_multiple (self ):
286
285
"""Test whether enumerate_minimal_adjustment_sets lists all minimum adjustment sets if multiple are possible."""
287
286
causal_dag = CausalDAG ()
288
- causal_dag .graph . add_edges_from (
287
+ causal_dag .add_edges_from (
289
288
[
290
289
("X1" , "X2" ),
291
290
("X2" , "V" ),
@@ -309,7 +308,7 @@ def test_enumerate_minimal_adjustment_sets_multiple(self):
309
308
def test_enumerate_minimal_adjustment_sets_two_adjustments (self ):
310
309
"""Test whether enumerate_minimal_adjustment_sets lists all possible minimum adjustment sets of arity two."""
311
310
causal_dag = CausalDAG ()
312
- causal_dag .graph . add_edges_from (
311
+ causal_dag .add_edges_from (
313
312
[
314
313
("X1" , "X2" ),
315
314
("X2" , "V" ),
@@ -336,7 +335,7 @@ def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
336
335
def test_dag_with_non_character_nodes (self ):
337
336
"""Test identification for a DAG whose nodes are not just characters (strings of length greater than 1)."""
338
337
causal_dag = CausalDAG ()
339
- causal_dag .graph . add_edges_from (
338
+ causal_dag .add_edges_from (
340
339
[
341
340
("va" , "ba" ),
342
341
("ba" , "ia" ),
@@ -351,7 +350,7 @@ def test_dag_with_non_character_nodes(self):
351
350
)
352
351
xs , ys = ["ba" ], ["da" ]
353
352
adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
354
- self .assertEqual (adjustment_sets , [{"aa" }, {"la" }, {"va" }])
353
+ self .assertEqual (list ( adjustment_sets ) , [{"aa" }, {"la" }, {"va" }])
355
354
356
355
def tearDown (self ) -> None :
357
356
shutil .rmtree (self .temp_dir_path )
@@ -485,148 +484,3 @@ def time_it(label, func, *args, **kwargs):
485
484
result = func (* args , ** kwargs )
486
485
print (f"{ label } took { time .time () - start :.6f} seconds" )
487
486
return result
488
-
489
-
490
- class TestOptimisedDAGIdentification (TestDAGIdentification ):
491
- """
492
- Test the Causal DAG identification algorithms and supporting algorithms.
493
- """
494
-
495
- def test_is_min_adjustment_for_not_min_adjustment (self ):
496
- """Test whether is_min_adjustment can correctly test whether the minimum adjustment set is not minimal."""
497
- causal_dag = CausalDAG (self .dag_dot_path )
498
- xs , ys , zs = ["X1" , "X2" ], ["Y" ], {"Z" , "V" }
499
-
500
- opt_dag = OptimisedCausalDAG (self .dag_dot_path )
501
-
502
- norm_result = time_it ("Norm" , lambda : causal_dag .adjustment_set_is_minimal (xs , ys , zs ))
503
- opt_result = time_it ("Opt" , lambda : opt_dag .adjustment_set_is_minimal (xs , ys , zs ))
504
- self .assertEqual (norm_result , opt_result )
505
-
506
- def test_is_min_adjustment_for_invalid_adjustment (self ):
507
- """Test whether is min_adjustment can correctly identify that the minimum adjustment set is invalid."""
508
- causal_dag = OptimisedCausalDAG (self .dag_dot_path )
509
- xs , ys , zs = ["X1" , "X2" ], ["Y" ], set ()
510
- self .assertRaises (ValueError , causal_dag .adjustment_set_is_minimal , xs , ys , zs )
511
-
512
- def test_get_ancestor_graph_of_causal_dag (self ):
513
- """Test whether get_ancestor_graph converts a CausalDAG to the correct ancestor graph."""
514
- causal_dag = OptimisedCausalDAG (self .dag_dot_path )
515
- xs , ys = ["X1" , "X2" ], ["Y" ]
516
- ancestor_graph = causal_dag .get_ancestor_graph (xs , ys )
517
- self .assertEqual (list (ancestor_graph .nodes ), ["X1" , "X2" , "D1" , "Y" , "Z" ])
518
- self .assertEqual (
519
- list (ancestor_graph .edges ),
520
- [("X1" , "X2" ), ("X2" , "D1" ), ("D1" , "Y" ), ("Z" , "X2" ), ("Z" , "Y" )],
521
- )
522
-
523
- def test_get_ancestor_graph_of_proper_backdoor_graph (self ):
524
- """Test whether get_ancestor_graph converts a CausalDAG to the correct proper back-door graph."""
525
- causal_dag = OptimisedCausalDAG (self .dag_dot_path )
526
- xs , ys = ["X1" , "X2" ], ["Y" ]
527
- proper_backdoor_graph = causal_dag .get_proper_backdoor_graph (xs , ys )
528
- ancestor_graph = proper_backdoor_graph .get_ancestor_graph (xs , ys )
529
- self .assertEqual (list (ancestor_graph .nodes ), ["X1" , "X2" , "D1" , "Y" , "Z" ])
530
- self .assertEqual (
531
- list (ancestor_graph .edges ),
532
- [("X1" , "X2" ), ("D1" , "Y" ), ("Z" , "X2" ), ("Z" , "Y" )],
533
- )
534
-
535
- def test_enumerate_minimal_adjustment_sets (self ):
536
- """Test whether enumerate_minimal_adjustment_sets lists all possible minimum sized adjustment sets."""
537
- causal_dag = OptimisedCausalDAG (self .dag_dot_path )
538
- xs , ys = ["X1" , "X2" ], ["Y" ]
539
- adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
540
- self .assertEqual ([{"Z" }], list (adjustment_sets ))
541
-
542
- def test_enumerate_minimal_adjustment_sets_multiple (self ):
543
- """Test whether enumerate_minimal_adjustment_sets lists all minimum adjustment sets if multiple are possible."""
544
- causal_dag = OptimisedCausalDAG ()
545
- causal_dag .add_edges_from (
546
- [
547
- ("X1" , "X2" ),
548
- ("X2" , "V" ),
549
- ("Z1" , "X2" ),
550
- ("Z1" , "Z2" ),
551
- ("Z2" , "Z3" ),
552
- ("Z3" , "Y" ),
553
- ("D1" , "Y" ),
554
- ("D1" , "D2" ),
555
- ("Y" , "D3" ),
556
- ]
557
- )
558
- opt_causal_dag = OptimisedCausalDAG ()
559
- opt_causal_dag .add_edges_from (
560
- [
561
- ("X1" , "X2" ),
562
- ("X2" , "V" ),
563
- ("Z1" , "X2" ),
564
- ("Z1" , "Z2" ),
565
- ("Z2" , "Z3" ),
566
- ("Z3" , "Y" ),
567
- ("D1" , "Y" ),
568
- ("D1" , "D2" ),
569
- ("Y" , "D3" ),
570
- ]
571
- )
572
- xs , ys = ["X1" , "X2" ], ["Y" ]
573
-
574
- norm_adjustment_sets = time_it ("Norm" , lambda : causal_dag .enumerate_minimal_adjustment_sets (xs , ys ))
575
-
576
- opt_adjustment_sets = time_it ("Opt" , lambda : opt_causal_dag .enumerate_minimal_adjustment_sets (xs , ys ))
577
- set_of_opt_adjustment_sets = set (frozenset (min_separator ) for min_separator in opt_adjustment_sets )
578
-
579
- self .assertEqual (
580
- {frozenset ({"Z1" }), frozenset ({"Z2" }), frozenset ({"Z3" })},
581
- set_of_opt_adjustment_sets ,
582
- )
583
-
584
- def test_enumerate_minimal_adjustment_sets_two_adjustments (self ):
585
- """Test whether enumerate_minimal_adjustment_sets lists all possible minimum adjustment sets of arity two."""
586
- causal_dag = OptimisedCausalDAG ()
587
- causal_dag .add_edges_from (
588
- [
589
- ("X1" , "X2" ),
590
- ("X2" , "V" ),
591
- ("Z1" , "X2" ),
592
- ("Z1" , "Z2" ),
593
- ("Z2" , "Z3" ),
594
- ("Z3" , "Y" ),
595
- ("D1" , "Y" ),
596
- ("D1" , "D2" ),
597
- ("Y" , "D3" ),
598
- ("Z4" , "X1" ),
599
- ("Z4" , "Y" ),
600
- ("X2" , "D1" ),
601
- ]
602
- )
603
- xs , ys = ["X1" , "X2" ], ["Y" ]
604
- adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
605
- set_of_adjustment_sets = set (frozenset (min_separator ) for min_separator in adjustment_sets )
606
- self .assertEqual (
607
- {frozenset ({"Z1" , "Z4" }), frozenset ({"Z2" , "Z4" }), frozenset ({"Z3" , "Z4" })},
608
- set_of_adjustment_sets ,
609
- )
610
-
611
- def test_dag_with_non_character_nodes (self ):
612
- """Test identification for a DAG whose nodes are not just characters (strings of length greater than 1)."""
613
- causal_dag = OptimisedCausalDAG ()
614
- causal_dag .add_edges_from (
615
- [
616
- ("va" , "ba" ),
617
- ("ba" , "ia" ),
618
- ("ba" , "da" ),
619
- ("ba" , "ra" ),
620
- ("la" , "va" ),
621
- ("la" , "aa" ),
622
- ("aa" , "ia" ),
623
- ("aa" , "da" ),
624
- ("aa" , "ra" ),
625
- ]
626
- )
627
- xs , ys = ["ba" ], ["da" ]
628
- adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
629
- self .assertEqual (list (adjustment_sets ), [{"aa" }, {"la" }, {"va" }])
630
-
631
- def tearDown (self ) -> None :
632
- shutil .rmtree (self .temp_dir_path )
0 commit comments