|
2 | 2 | import os
|
3 | 3 | import shutil, tempfile
|
4 | 4 | import networkx as nx
|
5 |
| -from causal_testing.specification.causal_dag import CausalDAG, close_separator, list_all_min_sep |
| 5 | +from causal_testing.specification.causal_dag import CausalDAG, close_separator, list_all_min_sep, OptimisedCausalDAG |
6 | 6 | from causal_testing.specification.scenario import Scenario
|
7 | 7 | from causal_testing.specification.variable import Input, Output
|
8 | 8 | from causal_testing.testing.base_test_case import BaseTestCase
|
@@ -475,3 +475,166 @@ def test_hidden_varaible_adjustment_sets(self):
|
475 | 475 |
|
476 | 476 | def tearDown(self) -> None:
|
477 | 477 | shutil.rmtree(self.temp_dir_path)
|
| 478 | + |
| 479 | +def time_it(label, func, *args, **kwargs): |
| 480 | + import time |
| 481 | + start = time.time() |
| 482 | + result = func(*args, **kwargs) |
| 483 | + print(f"{label} took {time.time() - start:.6f} seconds") |
| 484 | + return result |
| 485 | + |
| 486 | +class TestOptimisedDAGIdentification(TestDAGIdentification): |
| 487 | + """ |
| 488 | + Test the Causal DAG identification algorithms and supporting algorithms. |
| 489 | + """ |
| 490 | + |
| 491 | + def test_is_min_adjustment_for_not_min_adjustment(self): |
| 492 | + """Test whether is_min_adjustment can correctly test whether the minimum adjustment set is not minimal.""" |
| 493 | + causal_dag = CausalDAG(self.dag_dot_path) |
| 494 | + xs, ys, zs = ["X1", "X2"], ["Y"], {"Z", "V"} |
| 495 | + |
| 496 | + opt_dag = OptimisedCausalDAG(self.dag_dot_path) |
| 497 | + |
| 498 | + norm_result = time_it( |
| 499 | + "Norm", |
| 500 | + lambda: causal_dag.adjustment_set_is_minimal(xs, ys, zs) |
| 501 | + ) |
| 502 | + opt_result = time_it( |
| 503 | + "Opt", |
| 504 | + lambda: opt_dag.adjustment_set_is_minimal(xs, ys, zs) |
| 505 | + ) |
| 506 | + self.assertEqual(norm_result, opt_result) |
| 507 | + |
| 508 | + def test_is_min_adjustment_for_invalid_adjustment(self): |
| 509 | + """Test whether is min_adjustment can correctly identify that the minimum adjustment set is invalid.""" |
| 510 | + causal_dag = OptimisedCausalDAG(self.dag_dot_path) |
| 511 | + xs, ys, zs = ["X1", "X2"], ["Y"], set() |
| 512 | + self.assertRaises(ValueError, causal_dag.adjustment_set_is_minimal, xs, ys, zs) |
| 513 | + |
| 514 | + def test_get_ancestor_graph_of_causal_dag(self): |
| 515 | + """Test whether get_ancestor_graph converts a CausalDAG to the correct ancestor graph.""" |
| 516 | + causal_dag = OptimisedCausalDAG(self.dag_dot_path) |
| 517 | + xs, ys = ["X1", "X2"], ["Y"] |
| 518 | + ancestor_graph = causal_dag.get_ancestor_graph(xs, ys) |
| 519 | + self.assertEqual(list(ancestor_graph.nodes), ["X1", "X2", "D1", "Y", "Z"]) |
| 520 | + self.assertEqual( |
| 521 | + list(ancestor_graph.edges), |
| 522 | + [("X1", "X2"), ("X2", "D1"), ("D1", "Y"), ("Z", "X2"), ("Z", "Y")], |
| 523 | + ) |
| 524 | + |
| 525 | + def test_get_ancestor_graph_of_proper_backdoor_graph(self): |
| 526 | + """Test whether get_ancestor_graph converts a CausalDAG to the correct proper back-door graph.""" |
| 527 | + causal_dag = OptimisedCausalDAG(self.dag_dot_path) |
| 528 | + xs, ys = ["X1", "X2"], ["Y"] |
| 529 | + proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(xs, ys) |
| 530 | + ancestor_graph = proper_backdoor_graph.get_ancestor_graph(xs, ys) |
| 531 | + self.assertEqual(list(ancestor_graph.nodes), ["X1", "X2", "D1", "Y", "Z"]) |
| 532 | + self.assertEqual( |
| 533 | + list(ancestor_graph.edges), |
| 534 | + [("X1", "X2"), ("D1", "Y"), ("Z", "X2"), ("Z", "Y")], |
| 535 | + ) |
| 536 | + |
| 537 | + def test_enumerate_minimal_adjustment_sets(self): |
| 538 | + """Test whether enumerate_minimal_adjustment_sets lists all possible minimum sized adjustment sets.""" |
| 539 | + causal_dag = OptimisedCausalDAG(self.dag_dot_path) |
| 540 | + xs, ys = ["X1", "X2"], ["Y"] |
| 541 | + adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys) |
| 542 | + self.assertEqual([{"Z"}], adjustment_sets) |
| 543 | + |
| 544 | + def test_enumerate_minimal_adjustment_sets_multiple(self): |
| 545 | + """Test whether enumerate_minimal_adjustment_sets lists all minimum adjustment sets if multiple are possible.""" |
| 546 | + causal_dag = CausalDAG() |
| 547 | + causal_dag.graph.add_edges_from( |
| 548 | + [ |
| 549 | + ("X1", "X2"), |
| 550 | + ("X2", "V"), |
| 551 | + ("Z1", "X2"), |
| 552 | + ("Z1", "Z2"), |
| 553 | + ("Z2", "Z3"), |
| 554 | + ("Z3", "Y"), |
| 555 | + ("D1", "Y"), |
| 556 | + ("D1", "D2"), |
| 557 | + ("Y", "D3"), |
| 558 | + ] |
| 559 | + ) |
| 560 | + opt_causal_dag = CausalDAG() |
| 561 | + opt_causal_dag.graph.add_edges_from( |
| 562 | + [ |
| 563 | + ("X1", "X2"), |
| 564 | + ("X2", "V"), |
| 565 | + ("Z1", "X2"), |
| 566 | + ("Z1", "Z2"), |
| 567 | + ("Z2", "Z3"), |
| 568 | + ("Z3", "Y"), |
| 569 | + ("D1", "Y"), |
| 570 | + ("D1", "D2"), |
| 571 | + ("Y", "D3"), |
| 572 | + ] |
| 573 | + ) |
| 574 | + xs, ys = ["X1", "X2"], ["Y"] |
| 575 | + |
| 576 | + norm_adjustment_sets = time_it( |
| 577 | + "Norm", |
| 578 | + lambda: causal_dag.enumerate_minimal_adjustment_sets(xs, ys) |
| 579 | + ) |
| 580 | + |
| 581 | + opt_adjustment_sets = time_it( |
| 582 | + "Opt", |
| 583 | + lambda: opt_causal_dag.enumerate_minimal_adjustment_sets(xs, ys) |
| 584 | + ) |
| 585 | + set_of_opt_adjustment_sets = set(frozenset(min_separator) for min_separator in opt_adjustment_sets) |
| 586 | + |
| 587 | + self.assertEqual( |
| 588 | + {frozenset({"Z1"}), frozenset({"Z2"}), frozenset({"Z3"})}, |
| 589 | + set_of_opt_adjustment_sets, |
| 590 | + ) |
| 591 | + |
| 592 | + def test_enumerate_minimal_adjustment_sets_two_adjustments(self): |
| 593 | + """Test whether enumerate_minimal_adjustment_sets lists all possible minimum adjustment sets of arity two.""" |
| 594 | + causal_dag = OptimisedCausalDAG() |
| 595 | + causal_dag.graph.add_edges_from( |
| 596 | + [ |
| 597 | + ("X1", "X2"), |
| 598 | + ("X2", "V"), |
| 599 | + ("Z1", "X2"), |
| 600 | + ("Z1", "Z2"), |
| 601 | + ("Z2", "Z3"), |
| 602 | + ("Z3", "Y"), |
| 603 | + ("D1", "Y"), |
| 604 | + ("D1", "D2"), |
| 605 | + ("Y", "D3"), |
| 606 | + ("Z4", "X1"), |
| 607 | + ("Z4", "Y"), |
| 608 | + ("X2", "D1"), |
| 609 | + ] |
| 610 | + ) |
| 611 | + xs, ys = ["X1", "X2"], ["Y"] |
| 612 | + adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys) |
| 613 | + set_of_adjustment_sets = set(frozenset(min_separator) for min_separator in adjustment_sets) |
| 614 | + self.assertEqual( |
| 615 | + {frozenset({"Z1", "Z4"}), frozenset({"Z2", "Z4"}), frozenset({"Z3", "Z4"})}, |
| 616 | + set_of_adjustment_sets, |
| 617 | + ) |
| 618 | + |
| 619 | + def test_dag_with_non_character_nodes(self): |
| 620 | + """Test identification for a DAG whose nodes are not just characters (strings of length greater than 1).""" |
| 621 | + causal_dag = OptimisedCausalDAG() |
| 622 | + causal_dag.graph.add_edges_from( |
| 623 | + [ |
| 624 | + ("va", "ba"), |
| 625 | + ("ba", "ia"), |
| 626 | + ("ba", "da"), |
| 627 | + ("ba", "ra"), |
| 628 | + ("la", "va"), |
| 629 | + ("la", "aa"), |
| 630 | + ("aa", "ia"), |
| 631 | + ("aa", "da"), |
| 632 | + ("aa", "ra"), |
| 633 | + ] |
| 634 | + ) |
| 635 | + xs, ys = ["ba"], ["da"] |
| 636 | + adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys) |
| 637 | + self.assertEqual(adjustment_sets, [{"aa"}, {"la"}, {"va"}]) |
| 638 | + |
| 639 | + def tearDown(self) -> None: |
| 640 | + shutil.rmtree(self.temp_dir_path) |
0 commit comments