Skip to content

Commit f95a6f1

Browse files
Merge branch 'main' into json_concrete_param
2 parents 8bd0acd + 0a5c924 commit f95a6f1

File tree

4 files changed

+18
-6
lines changed

4 files changed

+18
-6
lines changed

causal_testing/testing/causal_test_case.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(
3030
control_value: Any = None,
3131
treatment_value: Any = None,
3232
estimate_type: str = "ate",
33+
estimate_params: dict = None,
3334
effect_modifier_configuration: dict[Variable:Any] = None,
3435
):
3536
"""
@@ -47,6 +48,8 @@ def __init__(
4748
self.treatment_variable = base_test_case.treatment_variable
4849
self.treatment_value = treatment_value
4950
self.estimate_type = estimate_type
51+
if estimate_params is None:
52+
self.estimate_params = {}
5053
self.effect = base_test_case.effect
5154

5255
if effect_modifier_configuration:

causal_testing/testing/causal_test_engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ def _return_causal_test_results(self, estimator, causal_test_case):
162162
)
163163
elif causal_test_case.estimate_type == "risk_ratio":
164164
logger.debug("calculating risk_ratio")
165-
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio()
165+
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio(**causal_test_case.estimate_params)
166+
166167
causal_test_result = CausalTestResult(
167168
estimator=estimator,
168169
test_value=TestValue("risk_ratio", risk_ratio),

causal_testing/testing/estimators.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
179179
# x = x[model.params.index]
180180
return model.predict(x)
181181

182-
def estimate_control_treatment(self, bootstrap_size=100, adjustment_config=None) -> tuple[pd.Series, pd.Series]:
182+
def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple[pd.Series, pd.Series]:
183183
"""Estimate the outcomes under control and treatment.
184184
185185
:return: The estimated control and treatment values and their confidence
@@ -215,14 +215,18 @@ def estimate_control_treatment(self, bootstrap_size=100, adjustment_config=None)
215215

216216
return (y.iloc[1], np.array(control)), (y.iloc[0], np.array(treatment))
217217

218-
def estimate_ate(self, bootstrap_size=100, adjustment_config=None) -> float:
218+
def estimate_ate(self, estimator_params: dict = None) -> float:
219219
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
220220
by changing the treatment variable from the control value to the treatment value. Here, we actually
221221
calculate the expected outcomes under control and treatment and take one away from the other. This
222222
allows for custom terms to be put in such as squares, inverses, products, etc.
223223
224224
:return: The estimated average treatment effect and 95% confidence intervals
225225
"""
226+
if estimator_params is None:
227+
estimator_params = {}
228+
bootstrap_size = estimator_params.get("bootstrap_size", 100)
229+
adjustment_config = estimator_params.get("adjustment_config", None)
226230
(control_outcome, control_bootstraps), (
227231
treatment_outcome,
228232
treatment_bootstraps,
@@ -245,14 +249,18 @@ def estimate_ate(self, bootstrap_size=100, adjustment_config=None) -> float:
245249

246250
return estimate, (ci_low, ci_high)
247251

248-
def estimate_risk_ratio(self, bootstrap_size=100, adjustment_config=None) -> float:
252+
def estimate_risk_ratio(self, estimator_params: dict = None) -> float:
249253
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
250254
by changing the treatment variable from the control value to the treatment value. Here, we actually
251255
calculate the expected outcomes under control and treatment and divide one by the other. This
252256
allows for custom terms to be put in such as squares, inverses, products, etc.
253257
254258
:return: The estimated risk ratio and 95% confidence intervals.
255259
"""
260+
if estimator_params is None:
261+
estimator_params = {}
262+
bootstrap_size = estimator_params.get("bootstrap_size", 100)
263+
adjustment_config = estimator_params.get("adjustment_config", None)
256264
(control_outcome, control_bootstraps), (
257265
treatment_outcome,
258266
treatment_bootstraps,

tests/testing_tests/test_estimators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,14 @@ def test_ate_adjustment(self):
124124
logistic_regression_estimator = LogisticRegressionEstimator(
125125
"length_in", 65, 55, {"large_gauge"}, "completed", df
126126
)
127-
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
127+
ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}})
128128
self.assertEqual(round(ate, 4), -0.3388)
129129

130130
def test_ate_invalid_adjustment(self):
131131
df = self.scarf_df.copy()
132132
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
133133
with self.assertRaises(ValueError):
134-
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
134+
ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}})
135135

136136
def test_ate_effect_modifiers(self):
137137
df = self.scarf_df.copy()

0 commit comments

Comments
 (0)