Skip to content

Commit 0ad05e2

Browse files
authored
Merge pull request #337 from CITCOM-project/fandango_experiment
Experimental optimisations, proposed by ChatGPT.
2 parents 756203c + 9816b98 commit 0ad05e2

File tree

8 files changed

+220
-240
lines changed

8 files changed

+220
-240
lines changed

causal_testing/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def load_dag(self) -> CausalDAG:
140140
"""
141141
logger.info(f"Loading DAG from {self.paths.dag_path}")
142142
dag = CausalDAG(str(self.paths.dag_path), ignore_cycles=self.ignore_cycles)
143-
logger.info(f"DAG loaded with {len(dag.graph.nodes)} nodes and {len(dag.graph.edges)} edges")
143+
logger.info(f"DAG loaded with {len(dag.nodes)} nodes and {len(dag.edges)} edges")
144144
return dag
145145

146146
def _read_dataframe(self, data_path):
@@ -172,18 +172,18 @@ def create_variables(self) -> None:
172172
"""
173173
Create variable objects from DAG nodes based on their connectivity.
174174
"""
175-
for node_name, node_data in self.dag.graph.nodes(data=True):
175+
for node_name, node_data in self.dag.nodes(data=True):
176176
if node_name not in self.data.columns and not node_data.get("hidden", False):
177177
raise ValueError(f"Node {node_name} missing from data. Should it be marked as hidden?")
178178

179179
dtype = self.data.dtypes.get(node_name)
180180

181181
# If node has no incoming edges, it's an input
182-
if self.dag.graph.in_degree(node_name) == 0:
182+
if self.dag.in_degree(node_name) == 0:
183183
self.variables["inputs"][node_name] = Input(name=node_name, datatype=dtype)
184184

185185
# Otherwise it's an output
186-
if self.dag.graph.in_degree(node_name) > 0:
186+
if self.dag.in_degree(node_name) > 0:
187187
self.variables["outputs"][node_name] = Output(name=node_name, datatype=dtype)
188188

189189
def create_scenario_and_specification(self) -> None:

causal_testing/specification/causal_dag.py

Lines changed: 164 additions & 213 deletions
Large diffs are not rendered by default.

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def generate_surrogates(
125125
surrogate_models = []
126126

127127
for u, v in specification.causal_dag.edges:
128-
edge_metadata = specification.causal_dag.graph.adj[u][v]
128+
edge_metadata = specification.causal_dag.adj[u][v]
129129
if "included" in edge_metadata:
130130
from_var = specification.scenario.variables.get(u)
131131
to_var = specification.scenario.variables.get(v)

causal_testing/testing/metamorphic_relation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,13 @@ def generate_metamorphic_relation(
133133
# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
134134
if ((u, v) not in dag.edges) and ((v, u) not in dag.edges):
135135
# Case 1: U --> ... --> V
136-
if u in nx.ancestors(dag.graph, v):
136+
if u in nx.ancestors(dag, v):
137137
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
138138
if adj_sets:
139139
metamorphic_relations.append(ShouldNotCause(BaseTestCase(u, v), list(adj_sets[0])))
140140

141141
# Case 2: V --> ... --> U
142-
elif v in nx.ancestors(dag.graph, u):
142+
elif v in nx.ancestors(dag, u):
143143
adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore)
144144
if adj_sets:
145145
metamorphic_relations.append(ShouldNotCause(BaseTestCase(v, u), list(adj_sets[0])))
@@ -221,7 +221,7 @@ def generate_causal_tests(
221221
causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles)
222222

223223
dag_nodes_to_test = [
224-
node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag.graph, "test", default=True)[node]
224+
node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag, "test", default=True)[node]
225225
]
226226

227227
if not causal_dag.is_acyclic() and ignore_cycles:
@@ -241,7 +241,7 @@ def generate_causal_tests(
241241
tests = [
242242
relation.to_json_stub(**json_stub_kargs)
243243
for relation in relations
244-
if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0
244+
if len(list(causal_dag.predecessors(relation.base_test_case.outcome_variable))) > 0
245245
]
246246

247247
logger.info(f"Generated {len(tests)} tests. Saving to {output_path}.")

tests/main_tests/test_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_load_data_query(self):
9393
def test_load_dag_missing_node(self):
9494
framework = CausalTestingFramework(self.paths)
9595
framework.setup()
96-
framework.dag.graph.add_node("missing")
96+
framework.dag.add_node("missing")
9797
with self.assertRaises(ValueError):
9898
framework.create_variables()
9999

tests/resources/data/dag.xml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<graphml xmlns="http://graphml.graphdrawing.org/xmlns"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns
5+
http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">
6+
<graph id="G" edgedefault="directed">
7+
<node id="Z"/>
8+
<node id="X"/>
9+
<node id="M"/>
10+
<node id="Y"/>
11+
<edge source="Z" target="X"/>
12+
<edge source="X" target="M"/>
13+
<edge source="M" target="Y"/>
14+
<edge source="Z" target="M"/>
15+
</graph>
16+
</graphml>

tests/specification_tests/test_causal_dag.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,18 @@ def setUp(self) -> None:
2020
with open(self.dag_dot_path, "w") as f:
2121
f.write(dag_dot)
2222

23+
def test_graphml(self):
24+
dot_dag = CausalDAG(self.dag_dot_path)
25+
xml_dag = CausalDAG(os.path.join("tests", "resources", "data", "dag.xml"))
26+
self.assertEqual(dot_dag.nodes, xml_dag.nodes)
27+
self.assertEqual(dot_dag.edges, xml_dag.edges)
28+
2329
def test_enumerate_minimal_adjustment_sets(self):
2430
"""Test whether enumerate_minimal_adjustment_sets lists all possible minimum sized adjustment sets."""
2531
causal_dag = CausalDAG(self.dag_dot_path)
2632
xs, ys = ["X"], ["Y"]
2733
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
28-
self.assertEqual([{"Z"}], adjustment_sets)
34+
self.assertEqual([{"Z"}], list(adjustment_sets))
2935

3036
def tearDown(self) -> None:
3137
shutil.rmtree(self.temp_dir_path)
@@ -46,19 +52,19 @@ def test_valid_iv(self):
4652

4753
def test_unrelated_instrument(self):
4854
causal_dag = CausalDAG(self.dag_dot_path)
49-
causal_dag.graph.remove_edge("I", "X")
55+
causal_dag.remove_edge("I", "X")
5056
with self.assertRaises(ValueError):
5157
causal_dag.check_iv_assumptions("X", "Y", "I")
5258

5359
def test_direct_cause(self):
5460
causal_dag = CausalDAG(self.dag_dot_path)
55-
causal_dag.graph.add_edge("I", "Y")
61+
causal_dag.add_edge("I", "Y")
5662
with self.assertRaises(ValueError):
5763
causal_dag.check_iv_assumptions("X", "Y", "I")
5864

5965
def test_common_cause(self):
6066
causal_dag = CausalDAG(self.dag_dot_path)
61-
causal_dag.graph.add_edge("U", "I")
67+
causal_dag.add_edge("U", "I")
6268
with self.assertRaises(ValueError):
6369
causal_dag.check_iv_assumptions("X", "Y", "I")
6470

@@ -279,12 +285,12 @@ def test_enumerate_minimal_adjustment_sets(self):
279285
causal_dag = CausalDAG(self.dag_dot_path)
280286
xs, ys = ["X1", "X2"], ["Y"]
281287
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
282-
self.assertEqual([{"Z"}], adjustment_sets)
288+
self.assertEqual([{"Z"}], list(adjustment_sets))
283289

284290
def test_enumerate_minimal_adjustment_sets_multiple(self):
285291
"""Test whether enumerate_minimal_adjustment_sets lists all minimum adjustment sets if multiple are possible."""
286292
causal_dag = CausalDAG()
287-
causal_dag.graph.add_edges_from(
293+
causal_dag.add_edges_from(
288294
[
289295
("X1", "X2"),
290296
("X2", "V"),
@@ -308,7 +314,7 @@ def test_enumerate_minimal_adjustment_sets_multiple(self):
308314
def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
309315
"""Test whether enumerate_minimal_adjustment_sets lists all possible minimum adjustment sets of arity two."""
310316
causal_dag = CausalDAG()
311-
causal_dag.graph.add_edges_from(
317+
causal_dag.add_edges_from(
312318
[
313319
("X1", "X2"),
314320
("X2", "V"),
@@ -335,7 +341,7 @@ def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
335341
def test_dag_with_non_character_nodes(self):
336342
"""Test identification for a DAG whose nodes are not just characters (strings of length greater than 1)."""
337343
causal_dag = CausalDAG()
338-
causal_dag.graph.add_edges_from(
344+
causal_dag.add_edges_from(
339345
[
340346
("va", "ba"),
341347
("ba", "ia"),
@@ -350,7 +356,7 @@ def test_dag_with_non_character_nodes(self):
350356
)
351357
xs, ys = ["ba"], ["da"]
352358
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
353-
self.assertEqual(adjustment_sets, [{"aa"}, {"la"}, {"va"}])
359+
self.assertEqual(list(adjustment_sets), [{"aa"}, {"la"}, {"va"}])
354360

355361
def tearDown(self) -> None:
356362
shutil.rmtree(self.temp_dir_path)
@@ -475,3 +481,12 @@ def test_hidden_varaible_adjustment_sets(self):
475481

476482
def tearDown(self) -> None:
477483
shutil.rmtree(self.temp_dir_path)
484+
485+
486+
def time_it(label, func, *args, **kwargs):
487+
import time
488+
489+
start = time.time()
490+
result = func(*args, **kwargs)
491+
print(f"{label} took {time.time() - start:.6f} seconds")
492+
return result

tests/testing_tests/test_metamorphic_relations.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_should_not_cause_json_stub(self):
4848
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
4949
and there is only a single input."""
5050
causal_dag = CausalDAG(self.dag_dot_path)
51-
causal_dag.graph.remove_nodes_from(["X2", "X3"])
51+
causal_dag.remove_nodes_from(["X2", "X3"])
5252
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
5353
should_not_cause_mr = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set)
5454
self.assertEqual(
@@ -70,7 +70,7 @@ def test_should_not_cause_logistic_json_stub(self):
7070
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
7171
and there is only a single input."""
7272
causal_dag = CausalDAG(self.dag_dot_path)
73-
causal_dag.graph.remove_nodes_from(["X2", "X3"])
73+
causal_dag.remove_nodes_from(["X2", "X3"])
7474
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
7575
should_not_cause_mr = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set)
7676
self.assertEqual(
@@ -94,7 +94,7 @@ def test_should_cause_json_stub(self):
9494
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
9595
and there is only a single input."""
9696
causal_dag = CausalDAG(self.dag_dot_path)
97-
causal_dag.graph.remove_nodes_from(["X2", "X3"])
97+
causal_dag.remove_nodes_from(["X2", "X3"])
9898
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
9999
should_cause_mr = ShouldCause(BaseTestCase("X1", "Z"), adj_set)
100100
self.assertEqual(
@@ -115,7 +115,7 @@ def test_should_cause_logistic_json_stub(self):
115115
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
116116
and there is only a single input."""
117117
causal_dag = CausalDAG(self.dag_dot_path)
118-
causal_dag.graph.remove_nodes_from(["X2", "X3"])
118+
causal_dag.remove_nodes_from(["X2", "X3"])
119119
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
120120
should_cause_mr = ShouldCause(BaseTestCase("X1", "Z"), adj_set)
121121
self.assertEqual(
@@ -265,8 +265,7 @@ def test_generate_causal_tests_ignore_cycles(self):
265265
map(
266266
lambda x: x.to_json_stub(skip=True),
267267
filter(
268-
lambda relation: len(list(dcg.graph.predecessors(relation.base_test_case.outcome_variable)))
269-
> 0,
268+
lambda relation: len(list(dcg.predecessors(relation.base_test_case.outcome_variable))) > 0,
270269
relations,
271270
),
272271
)
@@ -285,8 +284,7 @@ def test_generate_causal_tests(self):
285284
map(
286285
lambda x: x.to_json_stub(skip=True),
287286
filter(
288-
lambda relation: len(list(dag.graph.predecessors(relation.base_test_case.outcome_variable)))
289-
> 0,
287+
lambda relation: len(list(dag.predecessors(relation.base_test_case.outcome_variable))) > 0,
290288
relations,
291289
),
292290
)

0 commit comments

Comments
 (0)