Skip to content

Commit 671eb75

Browse files
authored
Merge branch 'main' into alpha
2 parents 9271cdd + 0a5c924 commit 671eb75

File tree

4 files changed

+18
-6
lines changed

4 files changed

+18
-6
lines changed

causal_testing/testing/causal_test_case.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(
3030
control_value: Any = None,
3131
treatment_value: Any = None,
3232
estimate_type: str = "ate",
33+
estimate_params: dict = None,
3334
effect_modifier_configuration: dict[Variable:Any] = None,
3435
):
3536
"""
@@ -47,6 +48,8 @@ def __init__(
4748
self.treatment_variable = base_test_case.treatment_variable
4849
self.treatment_value = treatment_value
4950
self.estimate_type = estimate_type
51+
if estimate_params is None:
52+
self.estimate_params = {}
5053
self.effect = base_test_case.effect
5154

5255
if effect_modifier_configuration:

causal_testing/testing/causal_test_engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ def _return_causal_test_results(self, estimator, causal_test_case):
162162
)
163163
elif causal_test_case.estimate_type == "risk_ratio":
164164
logger.debug("calculating risk_ratio")
165-
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio()
165+
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio(**causal_test_case.estimate_params)
166+
166167
causal_test_result = CausalTestResult(
167168
estimator=estimator,
168169
test_value=TestValue("risk_ratio", risk_ratio),

causal_testing/testing/estimators.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
181181
# x = x[model.params.index]
182182
return model.predict(x)
183183

184-
def estimate_control_treatment(self, bootstrap_size=100, adjustment_config=None) -> tuple[pd.Series, pd.Series]:
184+
def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple[pd.Series, pd.Series]:
185185
"""Estimate the outcomes under control and treatment.
186186
187187
:return: The estimated control and treatment values and their confidence
@@ -217,14 +217,18 @@ def estimate_control_treatment(self, bootstrap_size=100, adjustment_config=None)
217217

218218
return (y.iloc[1], np.array(control)), (y.iloc[0], np.array(treatment))
219219

220-
def estimate_ate(self, bootstrap_size=100, adjustment_config=None) -> float:
220+
def estimate_ate(self, estimator_params: dict = None) -> float:
221221
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
222222
by changing the treatment variable from the control value to the treatment value. Here, we actually
223223
calculate the expected outcomes under control and treatment and take one away from the other. This
224224
allows for custom terms to be put in such as squares, inverses, products, etc.
225225
226226
:return: The estimated average treatment effect and 95% confidence intervals
227227
"""
228+
if estimator_params is None:
229+
estimator_params = {}
230+
bootstrap_size = estimator_params.get("bootstrap_size", 100)
231+
adjustment_config = estimator_params.get("adjustment_config", None)
228232
(control_outcome, control_bootstraps), (
229233
treatment_outcome,
230234
treatment_bootstraps,
@@ -247,14 +251,18 @@ def estimate_ate(self, bootstrap_size=100, adjustment_config=None) -> float:
247251

248252
return estimate, (ci_low, ci_high)
249253

250-
def estimate_risk_ratio(self, bootstrap_size=100, adjustment_config=None) -> float:
254+
def estimate_risk_ratio(self, estimator_params: dict = None) -> float:
251255
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
252256
by changing the treatment variable from the control value to the treatment value. Here, we actually
253257
calculate the expected outcomes under control and treatment and divide one by the other. This
254258
allows for custom terms to be put in such as squares, inverses, products, etc.
255259
256260
:return: The estimated risk ratio and 95% confidence intervals.
257261
"""
262+
if estimator_params is None:
263+
estimator_params = {}
264+
bootstrap_size = estimator_params.get("bootstrap_size", 100)
265+
adjustment_config = estimator_params.get("adjustment_config", None)
258266
(control_outcome, control_bootstraps), (
259267
treatment_outcome,
260268
treatment_bootstraps,

tests/testing_tests/test_estimators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,14 @@ def test_ate_adjustment(self):
124124
logistic_regression_estimator = LogisticRegressionEstimator(
125125
"length_in", 65, 55, {"large_gauge"}, "completed", df
126126
)
127-
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
127+
ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}})
128128
self.assertEqual(round(ate, 4), -0.3388)
129129

130130
def test_ate_invalid_adjustment(self):
131131
df = self.scarf_df.copy()
132132
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
133133
with self.assertRaises(ValueError):
134-
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
134+
ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}})
135135

136136
def test_ate_effect_modifiers(self):
137137
df = self.scarf_df.copy()

0 commit comments

Comments
 (0)