Skip to content

Commit 45cb719

Browse files
committed
Updated docs, tests and examples
1 parent 3d3d3b4 commit 45cb719

File tree

6 files changed

+15
-18
lines changed

6 files changed

+15
-18
lines changed

docs/source/usage.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ various information. Here, we simply assert that the observed result is (on aver
107107
108108
causal_test_result = causal_test_engine.execute_test(
109109
estimator = estimation_model,
110-
causal_test_case = causal_test_case,
111-
estimate_type = "ate")
110+
causal_test_case = causal_test_case)
112111
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
113112
assert test_passes, "Expected to see a positive change in y."
114113

examples/covasim_/doubling_beta/example_beta.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def doubling_beta_CATE_on_csv(
6565
)
6666

6767
# Add squared terms for beta, since it has a quadratic relationship with cumulative infections
68-
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
68+
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
6969

7070
# Repeat for association estimate (no adjustment)
7171
no_adjustment_linear_regression_estimator = LinearRegressionEstimator(
@@ -78,7 +78,7 @@ def doubling_beta_CATE_on_csv(
7878
formula="cum_infections ~ beta + np.power(beta, 2)",
7979
)
8080
association_test_result = causal_test_engine.execute_test(
81-
no_adjustment_linear_regression_estimator, causal_test_case, "ate"
81+
no_adjustment_linear_regression_estimator, causal_test_case
8282
)
8383

8484
# Store results for plotting
@@ -110,7 +110,7 @@ def doubling_beta_CATE_on_csv(
110110
formula="cum_infections ~ beta + np.power(beta, 2) + avg_age + contacts",
111111
)
112112
counterfactual_causal_test_result = causal_test_engine.execute_test(
113-
linear_regression_estimator, causal_test_case, "ate"
113+
linear_regression_estimator, causal_test_case
114114
)
115115
results_dict["counterfactual"] = {
116116
"ate": counterfactual_causal_test_result.test_value.value,

examples/covasim_/vaccinating_elderly/example_vaccine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_experimental_vaccinate_elderly(runs_per_test_per_config: int = 30, verb
9999
)
100100

101101
# 10. Execute test and save results in dict
102-
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
102+
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
103103
if verbose:
104104
logging.info("Causation:\n%s", causal_test_result)
105105
results_dict[outcome_variable.name]["ate"] = causal_test_result.test_value.value

examples/lr91/example_max_conductances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
148148
)
149149

150150
# 10. Run the causal test and print results
151-
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
151+
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
152152
logger.info("%s", causal_test_result)
153153
return causal_test_result.test_value.value, causal_test_result.confidence_intervals
154154

examples/poisson-line-process/example_poisson_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def causal_test_intensity_num_shapes(
123123
)
124124

125125
# 10. Execute the test
126-
causal_test_result = causal_test_engine.execute_test(estimator, causal_test_case, causal_test_case.estimate_type)
126+
causal_test_result = causal_test_engine.execute_test(estimator, causal_test_case)
127127

128128
return causal_test_result
129129

tests/testing_tests/test_causal_test_engine.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,8 @@ def test_execute_test_observational_linear_regression_estimator_coefficient(self
189189
"A",
190190
self.causal_test_engine.scenario_execution_data_df,
191191
)
192-
causal_test_result = self.causal_test_engine.execute_test(
193-
estimation_model, self.causal_test_case, estimate_type="coefficient"
194-
)
192+
self.causal_test_case.estimate_type = "coefficient"
193+
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
195194
self.assertEqual(int(causal_test_result.test_value.value), 0)
196195

197196
def test_execute_test_observational_linear_regression_estimator_risk_ratio(self):
@@ -205,9 +204,8 @@ def test_execute_test_observational_linear_regression_estimator_risk_ratio(self)
205204
"A",
206205
self.causal_test_engine.scenario_execution_data_df,
207206
)
208-
causal_test_result = self.causal_test_engine.execute_test(
209-
estimation_model, self.causal_test_case, estimate_type="risk_ratio"
210-
)
207+
self.causal_test_case.estimate_type = "risk_ratio"
208+
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
211209
self.assertEqual(int(causal_test_result.test_value.value), 0)
212210

213211
def test_invalid_estimate_type(self):
@@ -221,8 +219,9 @@ def test_invalid_estimate_type(self):
221219
"A",
222220
self.causal_test_engine.scenario_execution_data_df,
223221
)
222+
self.causal_test_case.estimate_type = "invalid"
224223
with self.assertRaises(ValueError):
225-
self.causal_test_engine.execute_test(estimation_model, self.causal_test_case, estimate_type="invalid")
224+
self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
226225

227226
def test_execute_test_observational_linear_regression_estimator_squared_term(self):
228227
"""Check that executing the causal test case returns the correct results for dummy data with a squared term
@@ -258,9 +257,8 @@ def test_execute_observational_causal_forest_estimator_cates(self):
258257
self.causal_test_engine.scenario_execution_data_df,
259258
effect_modifiers={"M": None},
260259
)
261-
causal_test_result = self.causal_test_engine.execute_test(
262-
estimation_model, self.causal_test_case, estimate_type="cate"
263-
)
260+
self.causal_test_case.estimate_type = "cate"
261+
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
264262
causal_test_result = causal_test_result.test_value.value
265263
# Check that each effect modifier's strata has a greater ATE than the last (ascending order)
266264
causal_test_result_m1 = causal_test_result.loc[causal_test_result["M"] == 1]

0 commit comments

Comments
 (0)