Skip to content

Commit e4ec9d6

Browse files
Merge remote-tracking branch 'origin/adjustment_set_formula_check' into adjustment_set_formula_check
2 parents 10e4cb2 + 410da1c commit e4ec9d6

File tree

9 files changed

+150
-65
lines changed

9 files changed

+150
-65
lines changed

causal_testing/testing/causal_test_adequacy.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,26 @@ def __init__(
2626
self.test_suite = test_suite
2727
self.tested_pairs = None
2828
self.pairs_to_test = None
29-
self.untested_edges = None
29+
self.untested_pairs = None
3030
self.dag_adequacy = None
3131

3232
def measure_adequacy(self):
3333
"""
34-
Calculate the adequacy measurement, and populate the `dat_adequacy` field.
34+
Calculate the adequacy measurement, and populate the `dag_adequacy` field.
3535
"""
36-
self.tested_pairs = {(t.treatment_variable, t.outcome_variable) for t in self.test_suite}
37-
self.pairs_to_test = set(combinations(self.causal_dag.graph.nodes, 2))
38-
self.untested_edges = self.pairs_to_test.difference(self.tested_pairs)
36+
self.pairs_to_test = set(combinations(self.causal_dag.graph.nodes(), 2))
37+
self.tested_pairs = set()
38+
39+
for n1, n2 in self.pairs_to_test:
40+
if (n1, n2) in self.causal_dag.graph.edges():
41+
if any((t.treatment_variable, t.outcome_variable) == (n1, n2) for t in self.test_suite):
42+
self.tested_pairs.add((n1, n2))
43+
else:
44+
# Causal independences are not order dependent
45+
if any((t.treatment_variable, t.outcome_variable) in {(n1, n2), (n2, n1)} for t in self.test_suite):
46+
self.tested_pairs.add((n1, n2))
47+
48+
self.untested_pairs = self.pairs_to_test.difference(self.tested_pairs)
3949
self.dag_adequacy = len(self.tested_pairs) / len(self.pairs_to_test)
4050

4151
def to_dict(self):
@@ -45,7 +55,7 @@ def to_dict(self):
4555
"test_suite": self.test_suite,
4656
"tested_pairs": self.tested_pairs,
4757
"pairs_to_test": self.pairs_to_test,
48-
"untested_edges": self.untested_edges,
58+
"untested_pairs": self.untested_pairs,
4959
"dag_adequacy": self.dag_adequacy,
5060
}
5161

causal_testing/testing/causal_test_case.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,22 +57,6 @@ def __init__(
5757
else:
5858
self.effect_modifier_configuration = {}
5959

60-
def get_treatment_variable(self):
61-
"""Return the treatment variable name (as string) for this causal test case"""
62-
return self.treatment_variable.name
63-
64-
def get_outcome_variable(self):
65-
"""Return the outcome variable name (as string) for this causal test case."""
66-
return self.outcome_variable.name
67-
68-
def get_control_value(self):
69-
"""Return a the control value of the treatment variable in this causal test case."""
70-
return self.control_value
71-
72-
def get_treatment_value(self):
73-
"""Return the treatment value of the treatment variable in this causal test case."""
74-
return self.treatment_value
75-
7660
def execute_test(self, estimator: type(Estimator), data_collector: DataCollector) -> CausalTestResult:
7761
"""Execute a causal test case and return the causal test result.
7862

causal_testing/testing/causal_test_result.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,11 @@ def to_dict(self, json=False):
8585
"outcome": self.estimator.outcome,
8686
"adjustment_set": list(self.adjustment_set) if json else self.adjustment_set,
8787
"effect_measure": self.test_value.type,
88-
"effect_estimate": self.test_value.value,
89-
"ci_low": self.ci_low(),
90-
"ci_high": self.ci_high(),
88+
"effect_estimate": self.test_value.value.to_dict()
89+
if json and hasattr(self.test_value.value, "to_dict")
90+
else self.test_value.value,
91+
"ci_low": self.ci_low().to_dict() if json and hasattr(self.ci_low(), "to_dict") else self.ci_low(),
92+
"ci_high": self.ci_high().to_dict() if json and hasattr(self.ci_high(), "to_dict") else self.ci_high(),
9193
}
9294
if self.adequacy:
9395
base_dict["adequacy"] = self.adequacy.to_dict()

causal_testing/testing/estimators.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,16 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
161161
# x = x[model.params.index]
162162
return model.predict(x)
163163

164-
def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple[pd.Series, pd.Series]:
164+
def estimate_control_treatment(
165+
self, adjustment_config: dict = None, bootstrap_size: int = 100
166+
) -> tuple[pd.Series, pd.Series]:
165167
"""Estimate the outcomes under control and treatment.
166168
167169
:return: The estimated control and treatment values and their confidence
168170
intervals in the form ((ci_low, control, ci_high), (ci_low, treatment, ci_high)).
169171
"""
170-
172+
if adjustment_config is None:
173+
adjustment_config = {}
171174
y = self.estimate(self.df, adjustment_config=adjustment_config)
172175

173176
try:
@@ -197,18 +200,16 @@ def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple
197200

198201
return (y.iloc[1], np.array(control)), (y.iloc[0], np.array(treatment))
199202

200-
def estimate_ate(self, estimator_params: dict = None) -> float:
203+
def estimate_ate(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float:
201204
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
202205
by changing the treatment variable from the control value to the treatment value. Here, we actually
203206
calculate the expected outcomes under control and treatment and take one away from the other. This
204207
allows for custom terms to be put in such as squares, inverses, products, etc.
205208
206209
:return: The estimated average treatment effect and 95% confidence intervals
207210
"""
208-
if estimator_params is None:
209-
estimator_params = {}
210-
bootstrap_size = estimator_params.get("bootstrap_size", 100)
211-
adjustment_config = estimator_params.get("adjustment_config", None)
211+
if adjustment_config is None:
212+
adjustment_config = {}
212213
(control_outcome, control_bootstraps), (
213214
treatment_outcome,
214215
treatment_bootstraps,
@@ -231,18 +232,16 @@ def estimate_ate(self, estimator_params: dict = None) -> float:
231232

232233
return estimate, (ci_low, ci_high)
233234

234-
def estimate_risk_ratio(self, estimator_params: dict = None) -> float:
235+
def estimate_risk_ratio(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float:
235236
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
236237
by changing the treatment variable from the control value to the treatment value. Here, we actually
237238
calculate the expected outcomes under control and treatment and divide one by the other. This
238239
allows for custom terms to be put in such as squares, inverses, products, etc.
239240
240241
:return: The estimated risk ratio and 95% confidence intervals.
241242
"""
242-
if estimator_params is None:
243-
estimator_params = {}
244-
bootstrap_size = estimator_params.get("bootstrap_size", 100)
245-
adjustment_config = estimator_params.get("adjustment_config", None)
243+
if adjustment_config is None:
244+
adjustment_config = {}
246245
(control_outcome, control_bootstraps), (
247246
treatment_outcome,
248247
treatment_bootstraps,
@@ -374,7 +373,6 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
374373
"""
375374
if adjustment_config is None:
376375
adjustment_config = {}
377-
378376
model = self._run_linear_regression()
379377

380378
x = pd.DataFrame(columns=self.df.columns)
@@ -393,13 +391,15 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
393391

394392
return y.iloc[1], y.iloc[0]
395393

396-
def estimate_risk_ratio(self) -> tuple[float, list[float, float]]:
394+
def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
397395
"""Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
398396
by changing the treatment variable from the control value to the treatment value.
399397
400398
:return: The average treatment effect and the 95% Wald confidence intervals.
401399
"""
402-
control_outcome, treatment_outcome = self.estimate_control_treatment()
400+
if adjustment_config is None:
401+
adjustment_config = {}
402+
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
403403
ci_low = treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"]
404404
ci_high = treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"]
405405

@@ -413,6 +413,8 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float
413413
414414
:return: The average treatment effect and the 95% Wald confidence intervals.
415415
"""
416+
if adjustment_config is None:
417+
adjustment_config = {}
416418
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
417419
ci_low = treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"]
418420
ci_high = treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]

examples/poisson-line-process/example_poisson_process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def causal_test_intensity_num_shapes(
9090
# 8. Set up an estimator
9191
data = pd.read_csv(observational_data_path)
9292

93-
treatment = causal_test_case.get_treatment_variable()
94-
outcome = causal_test_case.get_outcome_variable()
93+
treatment = causal_test_case.treatment_variable.name
94+
outcome = causal_test_case.outcome_variable.name
9595

9696
estimator = None
9797
if empirical:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ dependencies = [
2020
"fitter~=1.4",
2121
"lhsmdu~=1.1",
2222
"networkx~=2.6",
23-
"numpy~=1.22.0",
23+
"numpy~=1.23",
2424
"pandas~=1.3",
2525
"scikit_learn~=1.1",
2626
"scipy~=1.7",

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 105 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ def test_data_adequacy_cateogorical(self):
107107
{"kurtosis": {"test_input_no_dist[T.b]": 0.0}, "bootstrap_size": 100, "passing": 100},
108108
)
109109

110-
def test_dag_adequacy(self):
110+
def test_dag_adequacy_dependent(self):
111111
base_test_case = BaseTestCase(
112112
treatment_variable="test_input",
113-
outcome_variable="test_output",
113+
outcome_variable="B",
114114
effect=None,
115115
)
116116
causal_test_case = CausalTestCase(
@@ -128,29 +128,127 @@ def test_dag_adequacy(self):
128128
{
129129
"causal_dag": self.json_class.causal_specification.causal_dag,
130130
"test_suite": test_suite,
131-
"tested_pairs": {("test_input", "test_output")},
131+
"tested_pairs": {("test_input", "B")},
132132
"pairs_to_test": {
133+
("B", "C"),
133134
("test_input_no_dist", "test_input"),
135+
("C", "test_output"),
134136
("test_input", "B"),
135-
("test_input_no_dist", "C"),
136137
("test_input_no_dist", "B"),
138+
("test_input", "test_output"),
137139
("test_input", "C"),
138-
("B", "C"),
139140
("test_input_no_dist", "test_output"),
141+
("B", "test_output"),
142+
("test_input_no_dist", "C"),
143+
},
144+
"untested_pairs": {
145+
("B", "C"),
146+
("test_input_no_dist", "test_input"),
147+
("C", "test_output"),
148+
("test_input_no_dist", "B"),
140149
("test_input", "test_output"),
150+
("test_input", "C"),
151+
("test_input_no_dist", "test_output"),
152+
("B", "test_output"),
153+
("test_input_no_dist", "C"),
154+
},
155+
"dag_adequacy": 0.1,
156+
},
157+
)
158+
159+
def test_dag_adequacy_independent(self):
160+
base_test_case = BaseTestCase(
161+
treatment_variable="test_input",
162+
outcome_variable="C",
163+
effect=None,
164+
)
165+
causal_test_case = CausalTestCase(
166+
base_test_case=base_test_case,
167+
expected_causal_effect=None,
168+
estimate_type=None,
169+
)
170+
test_suite = CausalTestSuite()
171+
test_suite.add_test_object(base_test_case, causal_test_case, None, None)
172+
dag_adequacy = DAGAdequacy(self.json_class.causal_specification.causal_dag, test_suite)
173+
dag_adequacy.measure_adequacy()
174+
print(dag_adequacy.to_dict())
175+
self.assertEqual(
176+
dag_adequacy.to_dict(),
177+
{
178+
"causal_dag": self.json_class.causal_specification.causal_dag,
179+
"test_suite": test_suite,
180+
"tested_pairs": {("test_input", "C")},
181+
"pairs_to_test": {
182+
("B", "C"),
183+
("test_input_no_dist", "test_input"),
141184
("C", "test_output"),
185+
("test_input", "B"),
186+
("test_input_no_dist", "B"),
187+
("test_input", "test_output"),
188+
("test_input", "C"),
189+
("test_input_no_dist", "test_output"),
142190
("B", "test_output"),
191+
("test_input_no_dist", "C"),
143192
},
144-
"untested_edges": {
193+
"untested_pairs": {
194+
("B", "C"),
145195
("test_input_no_dist", "test_input"),
196+
("C", "test_output"),
197+
("test_input_no_dist", "B"),
198+
("test_input", "test_output"),
146199
("test_input", "B"),
200+
("test_input_no_dist", "test_output"),
201+
("B", "test_output"),
147202
("test_input_no_dist", "C"),
203+
},
204+
"dag_adequacy": 0.1,
205+
},
206+
)
207+
208+
def test_dag_adequacy_independent_other_way(self):
209+
base_test_case = BaseTestCase(
210+
treatment_variable="C",
211+
outcome_variable="test_input",
212+
effect=None,
213+
)
214+
causal_test_case = CausalTestCase(
215+
base_test_case=base_test_case,
216+
expected_causal_effect=None,
217+
estimate_type=None,
218+
)
219+
test_suite = CausalTestSuite()
220+
test_suite.add_test_object(base_test_case, causal_test_case, None, None)
221+
dag_adequacy = DAGAdequacy(self.json_class.causal_specification.causal_dag, test_suite)
222+
dag_adequacy.measure_adequacy()
223+
print(dag_adequacy.to_dict())
224+
self.assertEqual(
225+
dag_adequacy.to_dict(),
226+
{
227+
"causal_dag": self.json_class.causal_specification.causal_dag,
228+
"test_suite": test_suite,
229+
"tested_pairs": {("test_input", "C")},
230+
"pairs_to_test": {
231+
("B", "C"),
232+
("test_input_no_dist", "test_input"),
233+
("C", "test_output"),
234+
("test_input", "B"),
148235
("test_input_no_dist", "B"),
236+
("test_input", "test_output"),
149237
("test_input", "C"),
150-
("B", "C"),
151238
("test_input_no_dist", "test_output"),
239+
("B", "test_output"),
240+
("test_input_no_dist", "C"),
241+
},
242+
"untested_pairs": {
243+
("B", "C"),
244+
("test_input_no_dist", "test_input"),
152245
("C", "test_output"),
246+
("test_input_no_dist", "B"),
247+
("test_input", "test_output"),
248+
("test_input", "B"),
249+
("test_input_no_dist", "test_output"),
153250
("B", "test_output"),
251+
("test_input_no_dist", "C"),
154252
},
155253
"dag_adequacy": 0.1,
156254
},

tests/testing_tests/test_causal_test_case.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,6 @@ def setUp(self) -> None:
3737
treatment_value=1,
3838
)
3939

40-
def test_get_treatment_variable(self):
41-
self.assertEqual(self.causal_test_case.get_treatment_variable(), "A")
42-
43-
def test_get_outcome_variable(self):
44-
self.assertEqual(self.causal_test_case.get_outcome_variable(), "C")
45-
46-
def test_get_treatment_value(self):
47-
self.assertEqual(self.causal_test_case.get_treatment_value(), 1)
48-
49-
def test_get_control_value(self):
50-
self.assertEqual(self.causal_test_case.get_control_value(), 0)
51-
5240
def test_str(self):
5341
self.assertEqual(
5442
str(self.causal_test_case),

tests/testing_tests/test_estimators.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,15 @@ def test_ate_adjustment(self):
124124
logistic_regression_estimator = LogisticRegressionEstimator(
125125
"length_in", 65, 55, {"large_gauge"}, "completed", df
126126
)
127-
ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}})
127+
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config = {"large_gauge": 0})
128128
self.assertEqual(round(ate, 4), -0.3388)
129129

130130
def test_ate_invalid_adjustment(self):
131131
df = self.scarf_df.copy()
132132
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
133133
with self.assertRaises(ValueError):
134134
ate, _ = logistic_regression_estimator.estimate_ate(
135-
estimator_params={"adjustment_config": {"large_gauge": 0}}
135+
adjustment_config = {"large_gauge": 0}
136136
)
137137

138138
def test_ate_effect_modifiers(self):
@@ -392,8 +392,9 @@ def test_program_15_no_interaction_ate_calculated(self):
392392
)
393393
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
394394
# for term_to_square in terms_to_square:
395+
395396
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
396-
{k: self.nhefs_df.mean()[k] for k in covariates}
397+
adjustment_config = {k: self.nhefs_df.mean()[k] for k in covariates}
397398
)
398399
self.assertEqual(round(ate, 1), 3.5)
399400
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])

0 commit comments

Comments
 (0)