Skip to content

Commit c38706d

Browse files
committed
Fixed default behaviour of MR generation
1 parent 2786fdc commit c38706d

File tree

5 files changed

+35
-27
lines changed

5 files changed

+35
-27
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,14 @@ def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
151151
else:
152152
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
153153

154+
@property
155+
def nodes(self):
156+
return self.graph.nodes
157+
158+
@property
159+
def edges(self):
160+
return self.graph.edges
161+
154162
def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
155163
"""
156164
Checks the three instrumental variable assumptions, raising a
@@ -170,7 +178,7 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
170178

171179
# (iii) Instrument and outcome do not share causes
172180

173-
for cause in self.graph.nodes:
181+
for cause in self.nodes:
174182
# Exclude self-cycles due to breaking changes in NetworkX > 3.2
175183
outcome_paths = (
176184
list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) if cause != outcome else []
@@ -222,8 +230,8 @@ def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str])
222230
:return: A CausalDAG corresponding to the proper back-door graph.
223231
"""
224232
for var in treatments + outcomes:
225-
if var not in self.graph.nodes:
226-
raise IndexError(f"{var} not a node in Causal DAG.\nValid nodes are{self.graph.nodes}.")
233+
if var not in self.nodes:
234+
raise IndexError(f"{var} not a node in Causal DAG.\nValid nodes are{self.nodes}.")
227235

228236
proper_backdoor_graph = self.copy()
229237
nodes_on_proper_causal_path = proper_backdoor_graph.proper_causal_pathway(treatments, outcomes)
@@ -255,7 +263,7 @@ def get_ancestor_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
255263
*[nx.ancestors(ancestor_graph.graph, outcome).union({outcome}) for outcome in outcomes]
256264
)
257265
variables_to_keep = treatment_ancestors.union(outcome_ancestors)
258-
variables_to_remove = set(self.graph.nodes).difference(variables_to_keep)
266+
variables_to_remove = set(self.nodes).difference(variables_to_keep)
259267
ancestor_graph.graph.remove_nodes_from(variables_to_remove)
260268
return ancestor_graph
261269

@@ -273,7 +281,7 @@ def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
273281
ee = []
274282
for s in treatments:
275283
for t in outcomes:
276-
if (s, t) in gback.graph.edges:
284+
if (s, t) in gback.edges:
277285
ee.append((s, t))
278286
for v1, v2 in ee:
279287
gback.graph.remove_edge(v1, v2)
@@ -451,7 +459,7 @@ def constructive_backdoor_criterion(
451459
]
452460
)
453461

454-
if not set(covariates).issubset(set(self.graph.nodes).difference(descendents_of_proper_casual_paths)):
462+
if not set(covariates).issubset(set(self.nodes).difference(descendents_of_proper_casual_paths)):
455463
logger.info(
456464
"Failed Condition 1: Z=%s **is** a descendent of some variable on a proper causal "
457465
"path between X=%s and Y=%s.",
@@ -566,9 +574,9 @@ def to_dot_string(self) -> str:
566574
:return DOT string of the DAG.
567575
"""
568576
dotstring = "digraph G {\n"
569-
dotstring += "".join([f"{a} -> {b};\n" for a, b in self.graph.edges])
577+
dotstring += "".join([f"{a} -> {b};\n" for a, b in self.edges])
570578
dotstring += "}"
571579
return dotstring
572580

573581
def __str__(self):
574-
return f"Nodes: {self.graph.nodes}\nEdges: {self.graph.edges}"
582+
return f"Nodes: {self.nodes}\nEdges: {self.edges}"

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def generate_surrogates(
121121
"""
122122
surrogate_models = []
123123

124-
for u, v in specification.causal_dag.graph.edges:
124+
for u, v in specification.causal_dag.edges:
125125
edge_metadata = specification.causal_dag.graph.adj[u][v]
126126
if "included" in edge_metadata:
127127
from_var = specification.scenario.variables.get(u)

causal_testing/testing/causal_test_adequacy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ def measure_adequacy(self):
3838
"""
3939
Calculate the adequacy measurement, and populate the `dag_adequacy` field.
4040
"""
41-
self.pairs_to_test = set(combinations(self.causal_dag.graph.nodes(), 2))
41+
self.pairs_to_test = set(combinations(self.causal_dag.nodes, 2))
4242
self.tested_pairs = set()
4343

4444
for n1, n2 in self.pairs_to_test:
45-
if (n1, n2) in self.causal_dag.graph.edges():
45+
if (n1, n2) in self.causal_dag.edges():
4646
if any((t.treatment_variable, t.outcome_variable) == (n1, n2) for t in self.test_suite):
4747
self.tested_pairs.add((n1, n2))
4848
else:

causal_testing/testing/metamorphic_relation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def generate_metamorphic_relation(
108108
metamorphic_relations = []
109109

110110
# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
111-
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
111+
if ((u, v) not in dag.edges) and ((v, u) not in dag.edges):
112112
# Case 1: U --> ... --> V
113113
if u in nx.ancestors(dag.graph, v):
114114
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
@@ -129,7 +129,7 @@ def generate_metamorphic_relation(
129129
metamorphic_relations.append(ShouldNotCause(BaseTestCase(u, v), list(adj_sets[0])))
130130

131131
# Create a ShouldCause relation for each edge (u, v) or (v, u)
132-
elif (u, v) in dag.graph.edges:
132+
elif (u, v) in dag.edges:
133133
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
134134
if adj_sets:
135135
metamorphic_relations.append(ShouldCause(BaseTestCase(u, v), list(adj_sets[0])))
@@ -160,7 +160,7 @@ def generate_metamorphic_relations(
160160
nodes_to_ignore = {}
161161

162162
if nodes_to_test is None:
163-
nodes_to_test = dag.graph.nodes
163+
nodes_to_test = dag.nodes
164164

165165
if not threads:
166166
metamorphic_relations = [
@@ -205,9 +205,9 @@ def generate_metamorphic_relations(
205205

206206
causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles)
207207

208-
dag_nodes_to_test = set(
209-
k for k, v in nx.get_node_attributes(causal_dag.graph, "test", default=True).items() if v == "True"
210-
)
208+
dag_nodes_to_test = [
209+
node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag.graph, "test", default=True)[node]
210+
]
211211

212212
if not causal_dag.is_acyclic() and args.ignore_cycles:
213213
logger.warning(

tests/specification_tests/test_causal_dag.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_valid_causal_dag(self):
8686
"""Test whether the Causal DAG is valid."""
8787
causal_dag = CausalDAG(self.dag_dot_path)
8888
print(causal_dag)
89-
assert list(causal_dag.graph.nodes) == ["A", "B", "C", "D"] and list(causal_dag.graph.edges) == [
89+
assert list(causal_dag.nodes) == ["A", "B", "C", "D"] and list(causal_dag.edges) == [
9090
("A", "B"),
9191
("B", "C"),
9292
("D", "A"),
@@ -101,7 +101,7 @@ def test_invalid_causal_dag(self):
101101
def test_empty_casual_dag(self):
102102
"""Test whether an empty dag can be created."""
103103
causal_dag = CausalDAG()
104-
assert list(causal_dag.graph.nodes) == [] and list(causal_dag.graph.edges) == []
104+
assert list(causal_dag.nodes) == [] and list(causal_dag.edges) == []
105105

106106
def test_to_dot_string(self):
107107
causal_dag = CausalDAG(self.dag_dot_path)
@@ -174,10 +174,10 @@ def setUp(self) -> None:
174174
def test_get_indirect_graph(self):
175175
causal_dag = CausalDAG(self.dag_dot_path)
176176
indirect_graph = causal_dag.get_indirect_graph(["D1"], ["Y"])
177-
original_edges = list(causal_dag.graph.edges)
177+
original_edges = list(causal_dag.edges)
178178
original_edges.remove(("D1", "Y"))
179-
self.assertEqual(list(indirect_graph.graph.edges), original_edges)
180-
self.assertEqual(indirect_graph.graph.nodes, causal_dag.graph.nodes)
179+
self.assertEqual(list(indirect_graph.edges), original_edges)
180+
self.assertEqual(indirect_graph.nodes, causal_dag.nodes)
181181

182182
def test_proper_backdoor_graph(self):
183183
"""Test whether converting a Causal DAG to a proper back-door graph works correctly."""
@@ -195,7 +195,7 @@ def test_proper_backdoor_graph(self):
195195
("Z", "Y"),
196196
]
197197
)
198-
self.assertTrue(set(proper_backdoor_graph.graph.edges).issubset(edges))
198+
self.assertTrue(set(proper_backdoor_graph.edges).issubset(edges))
199199

200200
def test_constructive_backdoor_criterion_should_hold(self):
201201
"""Test whether the constructive criterion holds when it should."""
@@ -246,9 +246,9 @@ def test_get_ancestor_graph_of_causal_dag(self):
246246
causal_dag = CausalDAG(self.dag_dot_path)
247247
xs, ys = ["X1", "X2"], ["Y"]
248248
ancestor_graph = causal_dag.get_ancestor_graph(xs, ys)
249-
self.assertEqual(list(ancestor_graph.graph.nodes), ["X1", "X2", "D1", "Y", "Z"])
249+
self.assertEqual(list(ancestor_graph.nodes), ["X1", "X2", "D1", "Y", "Z"])
250250
self.assertEqual(
251-
list(ancestor_graph.graph.edges),
251+
list(ancestor_graph.edges),
252252
[("X1", "X2"), ("X2", "D1"), ("D1", "Y"), ("Z", "X2"), ("Z", "Y")],
253253
)
254254

@@ -258,9 +258,9 @@ def test_get_ancestor_graph_of_proper_backdoor_graph(self):
258258
xs, ys = ["X1", "X2"], ["Y"]
259259
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(xs, ys)
260260
ancestor_graph = proper_backdoor_graph.get_ancestor_graph(xs, ys)
261-
self.assertEqual(list(ancestor_graph.graph.nodes), ["X1", "X2", "D1", "Y", "Z"])
261+
self.assertEqual(list(ancestor_graph.nodes), ["X1", "X2", "D1", "Y", "Z"])
262262
self.assertEqual(
263-
list(ancestor_graph.graph.edges),
263+
list(ancestor_graph.edges),
264264
[("X1", "X2"), ("D1", "Y"), ("Z", "X2"), ("Z", "Y")],
265265
)
266266

0 commit comments

Comments
 (0)