Skip to content

Commit bf0c3c5

Browse files
committed
Estimates now taken from causal test case when executing test
1 parent 9310ebe commit bf0c3c5

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,7 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
103103
test_suite_results[edge] = results
104104
return test_suite_results
105105

106-
def execute_test(
107-
self, estimator: type(Estimator), causal_test_case: CausalTestCase, estimate_type: str = "ate"
108-
) -> CausalTestResult:
106+
def execute_test(self, estimator: type(Estimator), causal_test_case: CausalTestCase) -> CausalTestResult:
109107
"""Execute a causal test case and return the causal test result.
110108
111109
Test case execution proceeds with the following steps:
@@ -142,18 +140,18 @@ def execute_test(
142140
if self._check_positivity_violation(variables_for_positivity):
143141
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")
144142

145-
causal_test_result = self._return_causal_test_results(estimate_type, estimator, causal_test_case)
143+
causal_test_result = self._return_causal_test_results(estimator, causal_test_case)
146144
return causal_test_result
147145

148-
def _return_causal_test_results(self, estimate_type, estimator, causal_test_case):
146+
def _return_causal_test_results(self, estimator, causal_test_case):
149147
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
150148
151149
:param estimate_type: A string which denotes the type of estimate to return
152150
:param estimator: An Estimator class object
153151
:param causal_test_case: The concrete test case to be executed
154152
:return: a CausalTestResult object containing the confidence intervals
155153
"""
156-
if estimate_type == "cate":
154+
if causal_test_case.estimate_type == "cate":
157155
logger.debug("calculating cate")
158156
if not hasattr(estimator, "estimate_cates"):
159157
raise NotImplementedError(f"{estimator.__class__} has no CATE method.")
@@ -165,7 +163,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
165163
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
166164
confidence_intervals=confidence_intervals,
167165
)
168-
elif estimate_type == "risk_ratio":
166+
elif causal_test_case.estimate_type == "risk_ratio":
169167
logger.debug("calculating risk_ratio")
170168
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio()
171169
causal_test_result = CausalTestResult(
@@ -174,7 +172,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
174172
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
175173
confidence_intervals=confidence_intervals,
176174
)
177-
elif estimate_type == "coefficient":
175+
elif causal_test_case.estimate_type == "coefficient":
178176
logger.debug("calculating coefficient")
179177
coefficient, confidence_intervals = estimator.estimate_unit_ate()
180178
causal_test_result = CausalTestResult(
@@ -183,7 +181,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
183181
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
184182
confidence_intervals=confidence_intervals,
185183
)
186-
elif estimate_type == "ate":
184+
elif causal_test_case.estimate_type == "ate":
187185
logger.debug("calculating ate")
188186
ate, confidence_intervals = estimator.estimate_ate()
189187
causal_test_result = CausalTestResult(
@@ -194,7 +192,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
194192
)
195193
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
196194
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
197-
elif estimate_type == "ate_calculated":
195+
elif causal_test_case.estimate_type == "ate_calculated":
198196
logger.debug("calculating ate")
199197
ate, confidence_intervals = estimator.estimate_ate_calculated()
200198
causal_test_result = CausalTestResult(
@@ -206,7 +204,9 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
206204
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
207205
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
208206
else:
209-
raise ValueError(f"Invalid estimate type {estimate_type}, expected 'ate', 'cate', or 'risk_ratio'")
207+
raise ValueError(
208+
f"Invalid estimate type {causal_test_case.estimate_type}, expected 'ate', 'cate', or 'risk_ratio'"
209+
)
210210
return causal_test_result
211211

212212
def _check_positivity_violation(self, variables_list):

0 commit comments

Comments
 (0)