@@ -103,9 +103,7 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
103
103
test_suite_results [edge ] = results
104
104
return test_suite_results
105
105
106
- def execute_test (
107
- self , estimator : type (Estimator ), causal_test_case : CausalTestCase , estimate_type : str = "ate"
108
- ) -> CausalTestResult :
106
+ def execute_test (self , estimator : type (Estimator ), causal_test_case : CausalTestCase ) -> CausalTestResult :
109
107
"""Execute a causal test case and return the causal test result.
110
108
111
109
Test case execution proceeds with the following steps:
@@ -142,18 +140,18 @@ def execute_test(
142
140
if self ._check_positivity_violation (variables_for_positivity ):
143
141
raise ValueError ("POSITIVITY VIOLATION -- Cannot proceed." )
144
142
145
- causal_test_result = self ._return_causal_test_results (estimate_type , estimator , causal_test_case )
143
+ causal_test_result = self ._return_causal_test_results (estimator , causal_test_case )
146
144
return causal_test_result
147
145
148
- def _return_causal_test_results (self , estimate_type , estimator , causal_test_case ):
146
+ def _return_causal_test_results (self , estimator , causal_test_case ):
149
147
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
150
148
151
149
:param estimate_type: A string which denotes the type of estimate to return
152
150
:param estimator: An Estimator class object
153
151
:param causal_test_case: The concrete test case to be executed
154
152
:return: a CausalTestResult object containing the confidence intervals
155
153
"""
156
- if estimate_type == "cate" :
154
+ if causal_test_case . estimate_type == "cate" :
157
155
logger .debug ("calculating cate" )
158
156
if not hasattr (estimator , "estimate_cates" ):
159
157
raise NotImplementedError (f"{ estimator .__class__ } has no CATE method." )
@@ -165,7 +163,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
165
163
effect_modifier_configuration = causal_test_case .effect_modifier_configuration ,
166
164
confidence_intervals = confidence_intervals ,
167
165
)
168
- elif estimate_type == "risk_ratio" :
166
+ elif causal_test_case . estimate_type == "risk_ratio" :
169
167
logger .debug ("calculating risk_ratio" )
170
168
risk_ratio , confidence_intervals = estimator .estimate_risk_ratio ()
171
169
causal_test_result = CausalTestResult (
@@ -174,7 +172,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
174
172
effect_modifier_configuration = causal_test_case .effect_modifier_configuration ,
175
173
confidence_intervals = confidence_intervals ,
176
174
)
177
- elif estimate_type == "coefficient" :
175
+ elif causal_test_case . estimate_type == "coefficient" :
178
176
logger .debug ("calculating coefficient" )
179
177
coefficient , confidence_intervals = estimator .estimate_unit_ate ()
180
178
causal_test_result = CausalTestResult (
@@ -183,7 +181,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
183
181
effect_modifier_configuration = causal_test_case .effect_modifier_configuration ,
184
182
confidence_intervals = confidence_intervals ,
185
183
)
186
- elif estimate_type == "ate" :
184
+ elif causal_test_case . estimate_type == "ate" :
187
185
logger .debug ("calculating ate" )
188
186
ate , confidence_intervals = estimator .estimate_ate ()
189
187
causal_test_result = CausalTestResult (
@@ -194,7 +192,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
194
192
)
195
193
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
196
194
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
197
- elif estimate_type == "ate_calculated" :
195
+ elif causal_test_case . estimate_type == "ate_calculated" :
198
196
logger .debug ("calculating ate" )
199
197
ate , confidence_intervals = estimator .estimate_ate_calculated ()
200
198
causal_test_result = CausalTestResult (
@@ -206,7 +204,9 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
206
204
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
207
205
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
208
206
else :
209
- raise ValueError (f"Invalid estimate type { estimate_type } , expected 'ate', 'cate', or 'risk_ratio'" )
207
+ raise ValueError (
208
+ f"Invalid estimate type { causal_test_case .estimate_type } , expected 'ate', 'cate', or 'risk_ratio'"
209
+ )
210
210
return causal_test_result
211
211
212
212
def _check_positivity_violation (self , variables_list ):
0 commit comments