Skip to content

Experimental optimisations, proposed by ChatGPT. #337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions causal_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def load_dag(self) -> CausalDAG:
"""
logger.info(f"Loading DAG from {self.paths.dag_path}")
dag = CausalDAG(str(self.paths.dag_path), ignore_cycles=self.ignore_cycles)
logger.info(f"DAG loaded with {len(dag.graph.nodes)} nodes and {len(dag.graph.edges)} edges")
logger.info(f"DAG loaded with {len(dag.nodes)} nodes and {len(dag.edges)} edges")
return dag

def _read_dataframe(self, data_path):
Expand Down Expand Up @@ -172,18 +172,18 @@ def create_variables(self) -> None:
"""
Create variable objects from DAG nodes based on their connectivity.
"""
for node_name, node_data in self.dag.graph.nodes(data=True):
for node_name, node_data in self.dag.nodes(data=True):
if node_name not in self.data.columns and not node_data.get("hidden", False):
raise ValueError(f"Node {node_name} missing from data. Should it be marked as hidden?")

dtype = self.data.dtypes.get(node_name)

# If node has no incoming edges, it's an input
if self.dag.graph.in_degree(node_name) == 0:
if self.dag.in_degree(node_name) == 0:
self.variables["inputs"][node_name] = Input(name=node_name, datatype=dtype)

# Otherwise it's an output
if self.dag.graph.in_degree(node_name) > 0:
if self.dag.in_degree(node_name) > 0:
self.variables["outputs"][node_name] = Output(name=node_name, datatype=dtype)

def create_scenario_and_specification(self) -> None:
Expand Down
377 changes: 164 additions & 213 deletions causal_testing/specification/causal_dag.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion causal_testing/surrogate/causal_surrogate_assisted.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def generate_surrogates(
surrogate_models = []

for u, v in specification.causal_dag.edges:
edge_metadata = specification.causal_dag.graph.adj[u][v]
edge_metadata = specification.causal_dag.adj[u][v]
if "included" in edge_metadata:
from_var = specification.scenario.variables.get(u)
to_var = specification.scenario.variables.get(v)
Expand Down
8 changes: 4 additions & 4 deletions causal_testing/testing/metamorphic_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,13 @@ def generate_metamorphic_relation(
# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
if ((u, v) not in dag.edges) and ((v, u) not in dag.edges):
# Case 1: U --> ... --> V
if u in nx.ancestors(dag.graph, v):
if u in nx.ancestors(dag, v):
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
if adj_sets:
metamorphic_relations.append(ShouldNotCause(BaseTestCase(u, v), list(adj_sets[0])))

# Case 2: V --> ... --> U
elif v in nx.ancestors(dag.graph, u):
elif v in nx.ancestors(dag, u):
adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore)
if adj_sets:
metamorphic_relations.append(ShouldNotCause(BaseTestCase(v, u), list(adj_sets[0])))
Expand Down Expand Up @@ -221,7 +221,7 @@ def generate_causal_tests(
causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles)

dag_nodes_to_test = [
node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag.graph, "test", default=True)[node]
node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag, "test", default=True)[node]
]

if not causal_dag.is_acyclic() and ignore_cycles:
Expand All @@ -241,7 +241,7 @@ def generate_causal_tests(
tests = [
relation.to_json_stub(**json_stub_kargs)
for relation in relations
if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0
if len(list(causal_dag.predecessors(relation.base_test_case.outcome_variable))) > 0
]

logger.info(f"Generated {len(tests)} tests. Saving to {output_path}.")
Expand Down
2 changes: 1 addition & 1 deletion tests/main_tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_load_data_query(self):
def test_load_dag_missing_node(self):
framework = CausalTestingFramework(self.paths)
framework.setup()
framework.dag.graph.add_node("missing")
framework.dag.add_node("missing")
with self.assertRaises(ValueError):
framework.create_variables()

Expand Down
16 changes: 16 additions & 0 deletions tests/resources/data/dag.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<graphml xmlns="http://graphml.graphdrawing.org/xmlns"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns
http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">
<graph id="G" edgedefault="directed">
<node id="Z"/>
<node id="X"/>
<node id="M"/>
<node id="Y"/>
<edge source="Z" target="X"/>
<edge source="X" target="M"/>
<edge source="M" target="Y"/>
<edge source="Z" target="M"/>
</graph>
</graphml>
33 changes: 24 additions & 9 deletions tests/specification_tests/test_causal_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@ def setUp(self) -> None:
with open(self.dag_dot_path, "w") as f:
f.write(dag_dot)

def test_graphml(self):
    """A DAG loaded from the GraphML fixture should have the same nodes and edges as the DOT one."""
    graphml_path = os.path.join("tests", "resources", "data", "dag.xml")
    from_dot = CausalDAG(self.dag_dot_path)
    from_graphml = CausalDAG(graphml_path)
    # Compare the node and edge views of the two graphs.
    self.assertEqual(from_dot.nodes, from_graphml.nodes)
    self.assertEqual(from_dot.edges, from_graphml.edges)

def test_enumerate_minimal_adjustment_sets(self):
"""Test whether enumerate_minimal_adjustment_sets lists all possible minimum sized adjustment sets."""
causal_dag = CausalDAG(self.dag_dot_path)
xs, ys = ["X"], ["Y"]
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
self.assertEqual([{"Z"}], adjustment_sets)
self.assertEqual([{"Z"}], list(adjustment_sets))

def tearDown(self) -> None:
shutil.rmtree(self.temp_dir_path)
Expand All @@ -46,19 +52,19 @@ def test_valid_iv(self):

def test_unrelated_instrument(self):
causal_dag = CausalDAG(self.dag_dot_path)
causal_dag.graph.remove_edge("I", "X")
causal_dag.remove_edge("I", "X")
with self.assertRaises(ValueError):
causal_dag.check_iv_assumptions("X", "Y", "I")

def test_direct_cause(self):
causal_dag = CausalDAG(self.dag_dot_path)
causal_dag.graph.add_edge("I", "Y")
causal_dag.add_edge("I", "Y")
with self.assertRaises(ValueError):
causal_dag.check_iv_assumptions("X", "Y", "I")

def test_common_cause(self):
causal_dag = CausalDAG(self.dag_dot_path)
causal_dag.graph.add_edge("U", "I")
causal_dag.add_edge("U", "I")
with self.assertRaises(ValueError):
causal_dag.check_iv_assumptions("X", "Y", "I")

Expand Down Expand Up @@ -279,12 +285,12 @@ def test_enumerate_minimal_adjustment_sets(self):
causal_dag = CausalDAG(self.dag_dot_path)
xs, ys = ["X1", "X2"], ["Y"]
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
self.assertEqual([{"Z"}], adjustment_sets)
self.assertEqual([{"Z"}], list(adjustment_sets))

def test_enumerate_minimal_adjustment_sets_multiple(self):
"""Test whether enumerate_minimal_adjustment_sets lists all minimum adjustment sets if multiple are possible."""
causal_dag = CausalDAG()
causal_dag.graph.add_edges_from(
causal_dag.add_edges_from(
[
("X1", "X2"),
("X2", "V"),
Expand All @@ -308,7 +314,7 @@ def test_enumerate_minimal_adjustment_sets_multiple(self):
def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
"""Test whether enumerate_minimal_adjustment_sets lists all possible minimum adjustment sets of arity two."""
causal_dag = CausalDAG()
causal_dag.graph.add_edges_from(
causal_dag.add_edges_from(
[
("X1", "X2"),
("X2", "V"),
Expand All @@ -335,7 +341,7 @@ def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
def test_dag_with_non_character_nodes(self):
"""Test identification for a DAG whose nodes are not just characters (strings of length greater than 1)."""
causal_dag = CausalDAG()
causal_dag.graph.add_edges_from(
causal_dag.add_edges_from(
[
("va", "ba"),
("ba", "ia"),
Expand All @@ -350,7 +356,7 @@ def test_dag_with_non_character_nodes(self):
)
xs, ys = ["ba"], ["da"]
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
self.assertEqual(adjustment_sets, [{"aa"}, {"la"}, {"va"}])
self.assertEqual(list(adjustment_sets), [{"aa"}, {"la"}, {"va"}])

def tearDown(self) -> None:
shutil.rmtree(self.temp_dir_path)
Expand Down Expand Up @@ -475,3 +481,12 @@ def test_hidden_varaible_adjustment_sets(self):

def tearDown(self) -> None:
shutil.rmtree(self.temp_dir_path)


def time_it(label, func, *args, **kwargs):
    """
    Call ``func(*args, **kwargs)``, print how long the call took, and return its result.

    :param label: Text printed in front of the elapsed time.
    :param func: The callable to run and time.
    :param args: Positional arguments forwarded to ``func``.
    :param kwargs: Keyword arguments forwarded to ``func``.
    :return: Whatever ``func`` returns.
    """
    import time

    # perf_counter() is monotonic and high resolution, so the measured interval
    # cannot be skewed (or go negative) if the system wall clock is adjusted.
    start = time.perf_counter()
    result = func(*args, **kwargs)
    print(f"{label} took {time.perf_counter() - start:.6f} seconds")
    return result
14 changes: 6 additions & 8 deletions tests/testing_tests/test_metamorphic_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_should_not_cause_json_stub(self):
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
and there is only a single input."""
causal_dag = CausalDAG(self.dag_dot_path)
causal_dag.graph.remove_nodes_from(["X2", "X3"])
causal_dag.remove_nodes_from(["X2", "X3"])
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
should_not_cause_mr = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set)
self.assertEqual(
Expand All @@ -70,7 +70,7 @@ def test_should_not_cause_logistic_json_stub(self):
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
and there is only a single input."""
causal_dag = CausalDAG(self.dag_dot_path)
causal_dag.graph.remove_nodes_from(["X2", "X3"])
causal_dag.remove_nodes_from(["X2", "X3"])
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
should_not_cause_mr = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set)
self.assertEqual(
Expand All @@ -94,7 +94,7 @@ def test_should_cause_json_stub(self):
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
and there is only a single input."""
causal_dag = CausalDAG(self.dag_dot_path)
causal_dag.graph.remove_nodes_from(["X2", "X3"])
causal_dag.remove_nodes_from(["X2", "X3"])
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
should_cause_mr = ShouldCause(BaseTestCase("X1", "Z"), adj_set)
self.assertEqual(
Expand All @@ -115,7 +115,7 @@ def test_should_cause_logistic_json_stub(self):
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
and there is only a single input."""
causal_dag = CausalDAG(self.dag_dot_path)
causal_dag.graph.remove_nodes_from(["X2", "X3"])
causal_dag.remove_nodes_from(["X2", "X3"])
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
should_cause_mr = ShouldCause(BaseTestCase("X1", "Z"), adj_set)
self.assertEqual(
Expand Down Expand Up @@ -265,8 +265,7 @@ def test_generate_causal_tests_ignore_cycles(self):
map(
lambda x: x.to_json_stub(skip=True),
filter(
lambda relation: len(list(dcg.graph.predecessors(relation.base_test_case.outcome_variable)))
> 0,
lambda relation: len(list(dcg.predecessors(relation.base_test_case.outcome_variable))) > 0,
relations,
),
)
Expand All @@ -285,8 +284,7 @@ def test_generate_causal_tests(self):
map(
lambda x: x.to_json_stub(skip=True),
filter(
lambda relation: len(list(dag.graph.predecessors(relation.base_test_case.outcome_variable)))
> 0,
lambda relation: len(list(dag.predecessors(relation.base_test_case.outcome_variable))) > 0,
relations,
),
)
Expand Down
Loading