Skip to content

Commit 478fcd7

Browse files
Merge branch 'main' into adjustment_set_formula_check
2 parents 2d02e5b + ee8be4d commit 478fcd7

File tree

7 files changed

+129
-47
lines changed

7 files changed

+129
-47
lines changed

causal_testing/testing/causal_test_adequacy.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,26 @@ def __init__(
2626
self.test_suite = test_suite
2727
self.tested_pairs = None
2828
self.pairs_to_test = None
29-
self.untested_edges = None
29+
self.untested_pairs = None
3030
self.dag_adequacy = None
3131

3232
def measure_adequacy(self):
3333
"""
34-
Calculate the adequacy measurement, and populate the `dat_adequacy` field.
34+
Calculate the adequacy measurement, and populate the `dag_adequacy` field.
3535
"""
36-
self.tested_pairs = {(t.treatment_variable, t.outcome_variable) for t in self.test_suite}
37-
self.pairs_to_test = set(combinations(self.causal_dag.graph.nodes, 2))
38-
self.untested_edges = self.pairs_to_test.difference(self.tested_pairs)
36+
self.pairs_to_test = set(combinations(self.causal_dag.graph.nodes(), 2))
37+
self.tested_pairs = set()
38+
39+
for n1, n2 in self.pairs_to_test:
40+
if (n1, n2) in self.causal_dag.graph.edges():
41+
if any((t.treatment_variable, t.outcome_variable) == (n1, n2) for t in self.test_suite):
42+
self.tested_pairs.add((n1, n2))
43+
else:
44+
# Causal independences are not order dependent
45+
if any((t.treatment_variable, t.outcome_variable) in {(n1, n2), (n2, n1)} for t in self.test_suite):
46+
self.tested_pairs.add((n1, n2))
47+
48+
self.untested_pairs = self.pairs_to_test.difference(self.tested_pairs)
3949
self.dag_adequacy = len(self.tested_pairs) / len(self.pairs_to_test)
4050

4151
def to_dict(self):
@@ -45,7 +55,7 @@ def to_dict(self):
4555
"test_suite": self.test_suite,
4656
"tested_pairs": self.tested_pairs,
4757
"pairs_to_test": self.pairs_to_test,
48-
"untested_edges": self.untested_edges,
58+
"untested_pairs": self.untested_pairs,
4959
"dag_adequacy": self.dag_adequacy,
5060
}
5161

causal_testing/testing/causal_test_case.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,22 +57,6 @@ def __init__(
5757
else:
5858
self.effect_modifier_configuration = {}
5959

60-
def get_treatment_variable(self):
61-
"""Return the treatment variable name (as string) for this causal test case"""
62-
return self.treatment_variable.name
63-
64-
def get_outcome_variable(self):
65-
"""Return the outcome variable name (as string) for this causal test case."""
66-
return self.outcome_variable.name
67-
68-
def get_control_value(self):
69-
"""Return a the control value of the treatment variable in this causal test case."""
70-
return self.control_value
71-
72-
def get_treatment_value(self):
73-
"""Return the treatment value of the treatment variable in this causal test case."""
74-
return self.treatment_value
75-
7660
def execute_test(self, estimator: type(Estimator), data_collector: DataCollector) -> CausalTestResult:
7761
"""Execute a causal test case and return the causal test result.
7862

causal_testing/testing/causal_test_result.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,11 @@ def to_dict(self, json=False):
8585
"outcome": self.estimator.outcome,
8686
"adjustment_set": list(self.adjustment_set) if json else self.adjustment_set,
8787
"effect_measure": self.test_value.type,
88-
"effect_estimate": self.test_value.value,
89-
"ci_low": self.ci_low(),
90-
"ci_high": self.ci_high(),
88+
"effect_estimate": self.test_value.value.to_dict()
89+
if json and hasattr(self.test_value.value, "to_dict")
90+
else self.test_value.value,
91+
"ci_low": self.ci_low().to_dict() if json and hasattr(self.ci_low(), "to_dict") else self.ci_low(),
92+
"ci_high": self.ci_high().to_dict() if json and hasattr(self.ci_high(), "to_dict") else self.ci_high(),
9193
}
9294
if self.adequacy:
9395
base_dict["adequacy"] = self.adequacy.to_dict()

examples/poisson-line-process/example_poisson_process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def causal_test_intensity_num_shapes(
9090
# 8. Set up an estimator
9191
data = pd.read_csv(observational_data_path)
9292

93-
treatment = causal_test_case.get_treatment_variable()
94-
outcome = causal_test_case.get_outcome_variable()
93+
treatment = causal_test_case.treatment_variable.name
94+
outcome = causal_test_case.outcome_variable.name
9595

9696
estimator = None
9797
if empirical:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ dependencies = [
2020
"fitter~=1.4",
2121
"lhsmdu~=1.1",
2222
"networkx~=2.6",
23-
"numpy~=1.22.0",
23+
"numpy~=1.23",
2424
"pandas~=1.3",
2525
"scikit_learn~=1.1",
2626
"scipy~=1.7",

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 105 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ def test_data_adequacy_cateogorical(self):
107107
{"kurtosis": {"test_input_no_dist[T.b]": 0.0}, "bootstrap_size": 100, "passing": 100},
108108
)
109109

110-
def test_dag_adequacy(self):
110+
def test_dag_adequacy_dependent(self):
111111
base_test_case = BaseTestCase(
112112
treatment_variable="test_input",
113-
outcome_variable="test_output",
113+
outcome_variable="B",
114114
effect=None,
115115
)
116116
causal_test_case = CausalTestCase(
@@ -128,29 +128,127 @@ def test_dag_adequacy(self):
128128
{
129129
"causal_dag": self.json_class.causal_specification.causal_dag,
130130
"test_suite": test_suite,
131-
"tested_pairs": {("test_input", "test_output")},
131+
"tested_pairs": {("test_input", "B")},
132132
"pairs_to_test": {
133+
("B", "C"),
133134
("test_input_no_dist", "test_input"),
135+
("C", "test_output"),
134136
("test_input", "B"),
135-
("test_input_no_dist", "C"),
136137
("test_input_no_dist", "B"),
138+
("test_input", "test_output"),
137139
("test_input", "C"),
138-
("B", "C"),
139140
("test_input_no_dist", "test_output"),
141+
("B", "test_output"),
142+
("test_input_no_dist", "C"),
143+
},
144+
"untested_pairs": {
145+
("B", "C"),
146+
("test_input_no_dist", "test_input"),
147+
("C", "test_output"),
148+
("test_input_no_dist", "B"),
140149
("test_input", "test_output"),
150+
("test_input", "C"),
151+
("test_input_no_dist", "test_output"),
152+
("B", "test_output"),
153+
("test_input_no_dist", "C"),
154+
},
155+
"dag_adequacy": 0.1,
156+
},
157+
)
158+
159+
def test_dag_adequacy_independent(self):
160+
base_test_case = BaseTestCase(
161+
treatment_variable="test_input",
162+
outcome_variable="C",
163+
effect=None,
164+
)
165+
causal_test_case = CausalTestCase(
166+
base_test_case=base_test_case,
167+
expected_causal_effect=None,
168+
estimate_type=None,
169+
)
170+
test_suite = CausalTestSuite()
171+
test_suite.add_test_object(base_test_case, causal_test_case, None, None)
172+
dag_adequacy = DAGAdequacy(self.json_class.causal_specification.causal_dag, test_suite)
173+
dag_adequacy.measure_adequacy()
174+
print(dag_adequacy.to_dict())
175+
self.assertEqual(
176+
dag_adequacy.to_dict(),
177+
{
178+
"causal_dag": self.json_class.causal_specification.causal_dag,
179+
"test_suite": test_suite,
180+
"tested_pairs": {("test_input", "C")},
181+
"pairs_to_test": {
182+
("B", "C"),
183+
("test_input_no_dist", "test_input"),
141184
("C", "test_output"),
185+
("test_input", "B"),
186+
("test_input_no_dist", "B"),
187+
("test_input", "test_output"),
188+
("test_input", "C"),
189+
("test_input_no_dist", "test_output"),
142190
("B", "test_output"),
191+
("test_input_no_dist", "C"),
143192
},
144-
"untested_edges": {
193+
"untested_pairs": {
194+
("B", "C"),
145195
("test_input_no_dist", "test_input"),
196+
("C", "test_output"),
197+
("test_input_no_dist", "B"),
198+
("test_input", "test_output"),
146199
("test_input", "B"),
200+
("test_input_no_dist", "test_output"),
201+
("B", "test_output"),
147202
("test_input_no_dist", "C"),
203+
},
204+
"dag_adequacy": 0.1,
205+
},
206+
)
207+
208+
def test_dag_adequacy_independent_other_way(self):
209+
base_test_case = BaseTestCase(
210+
treatment_variable="C",
211+
outcome_variable="test_input",
212+
effect=None,
213+
)
214+
causal_test_case = CausalTestCase(
215+
base_test_case=base_test_case,
216+
expected_causal_effect=None,
217+
estimate_type=None,
218+
)
219+
test_suite = CausalTestSuite()
220+
test_suite.add_test_object(base_test_case, causal_test_case, None, None)
221+
dag_adequacy = DAGAdequacy(self.json_class.causal_specification.causal_dag, test_suite)
222+
dag_adequacy.measure_adequacy()
223+
print(dag_adequacy.to_dict())
224+
self.assertEqual(
225+
dag_adequacy.to_dict(),
226+
{
227+
"causal_dag": self.json_class.causal_specification.causal_dag,
228+
"test_suite": test_suite,
229+
"tested_pairs": {("test_input", "C")},
230+
"pairs_to_test": {
231+
("B", "C"),
232+
("test_input_no_dist", "test_input"),
233+
("C", "test_output"),
234+
("test_input", "B"),
148235
("test_input_no_dist", "B"),
236+
("test_input", "test_output"),
149237
("test_input", "C"),
150-
("B", "C"),
151238
("test_input_no_dist", "test_output"),
239+
("B", "test_output"),
240+
("test_input_no_dist", "C"),
241+
},
242+
"untested_pairs": {
243+
("B", "C"),
244+
("test_input_no_dist", "test_input"),
152245
("C", "test_output"),
246+
("test_input_no_dist", "B"),
247+
("test_input", "test_output"),
248+
("test_input", "B"),
249+
("test_input_no_dist", "test_output"),
153250
("B", "test_output"),
251+
("test_input_no_dist", "C"),
154252
},
155253
"dag_adequacy": 0.1,
156254
},

tests/testing_tests/test_causal_test_case.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,6 @@ def setUp(self) -> None:
3737
treatment_value=1,
3838
)
3939

40-
def test_get_treatment_variable(self):
41-
self.assertEqual(self.causal_test_case.get_treatment_variable(), "A")
42-
43-
def test_get_outcome_variable(self):
44-
self.assertEqual(self.causal_test_case.get_outcome_variable(), "C")
45-
46-
def test_get_treatment_value(self):
47-
self.assertEqual(self.causal_test_case.get_treatment_value(), 1)
48-
49-
def test_get_control_value(self):
50-
self.assertEqual(self.causal_test_case.get_control_value(), 0)
51-
5240
def test_str(self):
5341
self.assertEqual(
5442
str(self.causal_test_case),

0 commit comments

Comments
 (0)