Skip to content

Commit 64596f6

Browse files
committed
Integrated optimised CausalDAG class
1 parent edc8455 commit 64596f6

File tree

8 files changed

+184
-919
lines changed

8 files changed

+184
-919
lines changed

causal_testing/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def load_dag(self) -> CausalDAG:
131131
"""
132132
logger.info(f"Loading DAG from {self.paths.dag_path}")
133133
dag = CausalDAG(str(self.paths.dag_path), ignore_cycles=self.ignore_cycles)
134-
logger.info(f"DAG loaded with {len(dag.graph.nodes)} nodes and {len(dag.graph.edges)} edges")
134+
logger.info(f"DAG loaded with {len(dag.nodes)} nodes and {len(dag.edges)} edges")
135135
return dag
136136

137137
def _read_dataframe(self, data_path):
@@ -163,18 +163,18 @@ def create_variables(self) -> None:
163163
"""
164164
Create variable objects from DAG nodes based on their connectivity.
165165
"""
166-
for node_name, node_data in self.dag.graph.nodes(data=True):
166+
for node_name, node_data in self.dag.nodes(data=True):
167167
if node_name not in self.data.columns and not node_data.get("hidden", False):
168168
raise ValueError(f"Node {node_name} missing from data. Should it be marked as hidden?")
169169

170170
dtype = self.data.dtypes.get(node_name)
171171

172172
# If node has no incoming edges, it's an input
173-
if self.dag.graph.in_degree(node_name) == 0:
173+
if self.dag.in_degree(node_name) == 0:
174174
self.variables["inputs"][node_name] = Input(name=node_name, datatype=dtype)
175175

176176
# Otherwise it's an output
177-
if self.dag.graph.in_degree(node_name) > 0:
177+
if self.dag.in_degree(node_name) > 0:
178178
self.variables["outputs"][node_name] = Output(name=node_name, datatype=dtype)
179179

180180
def create_scenario_and_specification(self) -> None:

causal_testing/specification/causal_dag.py

Lines changed: 160 additions & 204 deletions
Large diffs are not rendered by default.

causal_testing/specification/optimised_causal_dag.py

Lines changed: 0 additions & 543 deletions
This file was deleted.

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def generate_surrogates(
125125
surrogate_models = []
126126

127127
for u, v in specification.causal_dag.edges:
128-
edge_metadata = specification.causal_dag.graph.adj[u][v]
128+
edge_metadata = specification.causal_dag.adj[u][v]
129129
if "included" in edge_metadata:
130130
from_var = specification.scenario.variables.get(u)
131131
to_var = specification.scenario.variables.get(v)

causal_testing/testing/metamorphic_relation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,13 @@ def generate_metamorphic_relation(
109109
# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
110110
if ((u, v) not in dag.edges) and ((v, u) not in dag.edges):
111111
# Case 1: U --> ... --> V
112-
if u in nx.ancestors(dag.graph, v):
112+
if u in nx.ancestors(dag, v):
113113
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
114114
if adj_sets:
115115
metamorphic_relations.append(ShouldNotCause(BaseTestCase(u, v), list(adj_sets[0])))
116116

117117
# Case 2: V --> ... --> U
118-
elif v in nx.ancestors(dag.graph, u):
118+
elif v in nx.ancestors(dag, u):
119119
adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore)
120120
if adj_sets:
121121
metamorphic_relations.append(ShouldNotCause(BaseTestCase(v, u), list(adj_sets[0])))
@@ -194,7 +194,7 @@ def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool =
194194
causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles)
195195

196196
dag_nodes_to_test = [
197-
node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag.graph, "test", default=True)[node]
197+
node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag, "test", default=True)[node]
198198
]
199199

200200
if not causal_dag.is_acyclic() and ignore_cycles:
@@ -214,7 +214,7 @@ def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool =
214214
tests = [
215215
relation.to_json_stub(skip=False)
216216
for relation in relations
217-
if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0
217+
if len(list(causal_dag.predecessors(relation.base_test_case.outcome_variable))) > 0
218218
]
219219

220220
logger.info(f"Generated {len(tests)} tests. Saving to {output_path}.")

tests/main_tests/test_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_load_data_query(self):
9393
def test_load_dag_missing_node(self):
9494
framework = CausalTestingFramework(self.paths)
9595
framework.setup()
96-
framework.dag.graph.add_node("missing")
96+
framework.dag.add_node("missing")
9797
with self.assertRaises(ValueError):
9898
framework.create_variables()
9999

tests/specification_tests/test_causal_dag.py

Lines changed: 10 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import os
33
import shutil, tempfile
44
import networkx as nx
5-
from causal_testing.specification.causal_dag import CausalDAG, close_separator, list_all_min_sep, CausalDAG
6-
from causal_testing.specification.optimised_causal_dag import CausalDAG as OptimisedCausalDAG
5+
from causal_testing.specification.causal_dag import CausalDAG, close_separator, list_all_min_sep
76
from causal_testing.specification.scenario import Scenario
87
from causal_testing.specification.variable import Input, Output
98
from causal_testing.testing.base_test_case import BaseTestCase
@@ -26,7 +25,7 @@ def test_enumerate_minimal_adjustment_sets(self):
2625
causal_dag = CausalDAG(self.dag_dot_path)
2726
xs, ys = ["X"], ["Y"]
2827
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
29-
self.assertEqual([{"Z"}], adjustment_sets)
28+
self.assertEqual([{"Z"}], list(adjustment_sets))
3029

3130
def tearDown(self) -> None:
3231
shutil.rmtree(self.temp_dir_path)
@@ -47,19 +46,19 @@ def test_valid_iv(self):
4746

4847
def test_unrelated_instrument(self):
4948
causal_dag = CausalDAG(self.dag_dot_path)
50-
causal_dag.graph.remove_edge("I", "X")
49+
causal_dag.remove_edge("I", "X")
5150
with self.assertRaises(ValueError):
5251
causal_dag.check_iv_assumptions("X", "Y", "I")
5352

5453
def test_direct_cause(self):
5554
causal_dag = CausalDAG(self.dag_dot_path)
56-
causal_dag.graph.add_edge("I", "Y")
55+
causal_dag.add_edge("I", "Y")
5756
with self.assertRaises(ValueError):
5857
causal_dag.check_iv_assumptions("X", "Y", "I")
5958

6059
def test_common_cause(self):
6160
causal_dag = CausalDAG(self.dag_dot_path)
62-
causal_dag.graph.add_edge("U", "I")
61+
causal_dag.add_edge("U", "I")
6362
with self.assertRaises(ValueError):
6463
causal_dag.check_iv_assumptions("X", "Y", "I")
6564

@@ -280,12 +279,12 @@ def test_enumerate_minimal_adjustment_sets(self):
280279
causal_dag = CausalDAG(self.dag_dot_path)
281280
xs, ys = ["X1", "X2"], ["Y"]
282281
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
283-
self.assertEqual([{"Z"}], adjustment_sets)
282+
self.assertEqual([{"Z"}], list(adjustment_sets))
284283

285284
def test_enumerate_minimal_adjustment_sets_multiple(self):
286285
"""Test whether enumerate_minimal_adjustment_sets lists all minimum adjustment sets if multiple are possible."""
287286
causal_dag = CausalDAG()
288-
causal_dag.graph.add_edges_from(
287+
causal_dag.add_edges_from(
289288
[
290289
("X1", "X2"),
291290
("X2", "V"),
@@ -309,7 +308,7 @@ def test_enumerate_minimal_adjustment_sets_multiple(self):
309308
def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
310309
"""Test whether enumerate_minimal_adjustment_sets lists all possible minimum adjustment sets of arity two."""
311310
causal_dag = CausalDAG()
312-
causal_dag.graph.add_edges_from(
311+
causal_dag.add_edges_from(
313312
[
314313
("X1", "X2"),
315314
("X2", "V"),
@@ -336,7 +335,7 @@ def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
336335
def test_dag_with_non_character_nodes(self):
337336
"""Test identification for a DAG whose nodes are not just characters (strings of length greater than 1)."""
338337
causal_dag = CausalDAG()
339-
causal_dag.graph.add_edges_from(
338+
causal_dag.add_edges_from(
340339
[
341340
("va", "ba"),
342341
("ba", "ia"),
@@ -351,7 +350,7 @@ def test_dag_with_non_character_nodes(self):
351350
)
352351
xs, ys = ["ba"], ["da"]
353352
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
354-
self.assertEqual(adjustment_sets, [{"aa"}, {"la"}, {"va"}])
353+
self.assertEqual(list(adjustment_sets), [{"aa"}, {"la"}, {"va"}])
355354

356355
def tearDown(self) -> None:
357356
shutil.rmtree(self.temp_dir_path)
@@ -485,148 +484,3 @@ def time_it(label, func, *args, **kwargs):
485484
result = func(*args, **kwargs)
486485
print(f"{label} took {time.time() - start:.6f} seconds")
487486
return result
488-
489-
490-
class TestOptimisedDAGIdentification(TestDAGIdentification):
491-
"""
492-
Test the Causal DAG identification algorithms and supporting algorithms.
493-
"""
494-
495-
def test_is_min_adjustment_for_not_min_adjustment(self):
496-
"""Test whether is_min_adjustment can correctly test whether the minimum adjustment set is not minimal."""
497-
causal_dag = CausalDAG(self.dag_dot_path)
498-
xs, ys, zs = ["X1", "X2"], ["Y"], {"Z", "V"}
499-
500-
opt_dag = OptimisedCausalDAG(self.dag_dot_path)
501-
502-
norm_result = time_it("Norm", lambda: causal_dag.adjustment_set_is_minimal(xs, ys, zs))
503-
opt_result = time_it("Opt", lambda: opt_dag.adjustment_set_is_minimal(xs, ys, zs))
504-
self.assertEqual(norm_result, opt_result)
505-
506-
def test_is_min_adjustment_for_invalid_adjustment(self):
507-
"""Test whether is min_adjustment can correctly identify that the minimum adjustment set is invalid."""
508-
causal_dag = OptimisedCausalDAG(self.dag_dot_path)
509-
xs, ys, zs = ["X1", "X2"], ["Y"], set()
510-
self.assertRaises(ValueError, causal_dag.adjustment_set_is_minimal, xs, ys, zs)
511-
512-
def test_get_ancestor_graph_of_causal_dag(self):
513-
"""Test whether get_ancestor_graph converts a CausalDAG to the correct ancestor graph."""
514-
causal_dag = OptimisedCausalDAG(self.dag_dot_path)
515-
xs, ys = ["X1", "X2"], ["Y"]
516-
ancestor_graph = causal_dag.get_ancestor_graph(xs, ys)
517-
self.assertEqual(list(ancestor_graph.nodes), ["X1", "X2", "D1", "Y", "Z"])
518-
self.assertEqual(
519-
list(ancestor_graph.edges),
520-
[("X1", "X2"), ("X2", "D1"), ("D1", "Y"), ("Z", "X2"), ("Z", "Y")],
521-
)
522-
523-
def test_get_ancestor_graph_of_proper_backdoor_graph(self):
524-
"""Test whether get_ancestor_graph converts a CausalDAG to the correct proper back-door graph."""
525-
causal_dag = OptimisedCausalDAG(self.dag_dot_path)
526-
xs, ys = ["X1", "X2"], ["Y"]
527-
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(xs, ys)
528-
ancestor_graph = proper_backdoor_graph.get_ancestor_graph(xs, ys)
529-
self.assertEqual(list(ancestor_graph.nodes), ["X1", "X2", "D1", "Y", "Z"])
530-
self.assertEqual(
531-
list(ancestor_graph.edges),
532-
[("X1", "X2"), ("D1", "Y"), ("Z", "X2"), ("Z", "Y")],
533-
)
534-
535-
def test_enumerate_minimal_adjustment_sets(self):
536-
"""Test whether enumerate_minimal_adjustment_sets lists all possible minimum sized adjustment sets."""
537-
causal_dag = OptimisedCausalDAG(self.dag_dot_path)
538-
xs, ys = ["X1", "X2"], ["Y"]
539-
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
540-
self.assertEqual([{"Z"}], list(adjustment_sets))
541-
542-
def test_enumerate_minimal_adjustment_sets_multiple(self):
543-
"""Test whether enumerate_minimal_adjustment_sets lists all minimum adjustment sets if multiple are possible."""
544-
causal_dag = OptimisedCausalDAG()
545-
causal_dag.add_edges_from(
546-
[
547-
("X1", "X2"),
548-
("X2", "V"),
549-
("Z1", "X2"),
550-
("Z1", "Z2"),
551-
("Z2", "Z3"),
552-
("Z3", "Y"),
553-
("D1", "Y"),
554-
("D1", "D2"),
555-
("Y", "D3"),
556-
]
557-
)
558-
opt_causal_dag = OptimisedCausalDAG()
559-
opt_causal_dag.add_edges_from(
560-
[
561-
("X1", "X2"),
562-
("X2", "V"),
563-
("Z1", "X2"),
564-
("Z1", "Z2"),
565-
("Z2", "Z3"),
566-
("Z3", "Y"),
567-
("D1", "Y"),
568-
("D1", "D2"),
569-
("Y", "D3"),
570-
]
571-
)
572-
xs, ys = ["X1", "X2"], ["Y"]
573-
574-
norm_adjustment_sets = time_it("Norm", lambda: causal_dag.enumerate_minimal_adjustment_sets(xs, ys))
575-
576-
opt_adjustment_sets = time_it("Opt", lambda: opt_causal_dag.enumerate_minimal_adjustment_sets(xs, ys))
577-
set_of_opt_adjustment_sets = set(frozenset(min_separator) for min_separator in opt_adjustment_sets)
578-
579-
self.assertEqual(
580-
{frozenset({"Z1"}), frozenset({"Z2"}), frozenset({"Z3"})},
581-
set_of_opt_adjustment_sets,
582-
)
583-
584-
def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
585-
"""Test whether enumerate_minimal_adjustment_sets lists all possible minimum adjustment sets of arity two."""
586-
causal_dag = OptimisedCausalDAG()
587-
causal_dag.add_edges_from(
588-
[
589-
("X1", "X2"),
590-
("X2", "V"),
591-
("Z1", "X2"),
592-
("Z1", "Z2"),
593-
("Z2", "Z3"),
594-
("Z3", "Y"),
595-
("D1", "Y"),
596-
("D1", "D2"),
597-
("Y", "D3"),
598-
("Z4", "X1"),
599-
("Z4", "Y"),
600-
("X2", "D1"),
601-
]
602-
)
603-
xs, ys = ["X1", "X2"], ["Y"]
604-
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
605-
set_of_adjustment_sets = set(frozenset(min_separator) for min_separator in adjustment_sets)
606-
self.assertEqual(
607-
{frozenset({"Z1", "Z4"}), frozenset({"Z2", "Z4"}), frozenset({"Z3", "Z4"})},
608-
set_of_adjustment_sets,
609-
)
610-
611-
def test_dag_with_non_character_nodes(self):
612-
"""Test identification for a DAG whose nodes are not just characters (strings of length greater than 1)."""
613-
causal_dag = OptimisedCausalDAG()
614-
causal_dag.add_edges_from(
615-
[
616-
("va", "ba"),
617-
("ba", "ia"),
618-
("ba", "da"),
619-
("ba", "ra"),
620-
("la", "va"),
621-
("la", "aa"),
622-
("aa", "ia"),
623-
("aa", "da"),
624-
("aa", "ra"),
625-
]
626-
)
627-
xs, ys = ["ba"], ["da"]
628-
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
629-
self.assertEqual(list(adjustment_sets), [{"aa"}, {"la"}, {"va"}])
630-
631-
def tearDown(self) -> None:
632-
shutil.rmtree(self.temp_dir_path)

tests/testing_tests/test_metamorphic_relations.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_should_not_cause_json_stub(self):
4848
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
4949
and there is only a single input."""
5050
causal_dag = CausalDAG(self.dag_dot_path)
51-
causal_dag.graph.remove_nodes_from(["X2", "X3"])
51+
causal_dag.remove_nodes_from(["X2", "X3"])
5252
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
5353
should_not_cause_MR = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set)
5454
self.assertEqual(
@@ -70,7 +70,7 @@ def test_should_cause_json_stub(self):
7070
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
7171
and there is only a single input."""
7272
causal_dag = CausalDAG(self.dag_dot_path)
73-
causal_dag.graph.remove_nodes_from(["X2", "X3"])
73+
causal_dag.remove_nodes_from(["X2", "X3"])
7474
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
7575
should_cause_MR = ShouldCause(BaseTestCase("X1", "Z"), adj_set)
7676
self.assertEqual(
@@ -218,8 +218,7 @@ def test_generate_causal_tests_ignore_cycles(self):
218218
map(
219219
lambda x: x.to_json_stub(skip=False),
220220
filter(
221-
lambda relation: len(list(dcg.graph.predecessors(relation.base_test_case.outcome_variable)))
222-
> 0,
221+
lambda relation: len(list(dcg.predecessors(relation.base_test_case.outcome_variable))) > 0,
223222
relations,
224223
),
225224
)
@@ -238,8 +237,7 @@ def test_generate_causal_tests(self):
238237
map(
239238
lambda x: x.to_json_stub(skip=False),
240239
filter(
241-
lambda relation: len(list(dag.graph.predecessors(relation.base_test_case.outcome_variable)))
242-
> 0,
240+
lambda relation: len(list(dag.predecessors(relation.base_test_case.outcome_variable))) > 0,
243241
relations,
244242
),
245243
)

0 commit comments

Comments
 (0)