@@ -81,7 +81,6 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
81
81
82
82
estimators = test_suite [edge ]["estimators" ]
83
83
tests = test_suite [edge ]["tests" ]
84
- estimate_type = test_suite [edge ]["estimate_type" ]
85
84
results = {}
86
85
for estimator_class in estimators :
87
86
causal_test_results = []
@@ -96,16 +95,14 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
96
95
)
97
96
if estimator .df is None :
98
97
estimator .df = self .scenario_execution_data_df
99
- causal_test_result = self ._return_causal_test_results (estimate_type , estimator , test )
98
+ causal_test_result = self ._return_causal_test_results (estimator , test )
100
99
causal_test_results .append (causal_test_result )
101
100
102
101
results [estimator_class .__name__ ] = causal_test_results
103
102
test_suite_results [edge ] = results
104
103
return test_suite_results
105
104
106
- def execute_test (
107
- self , estimator : type (Estimator ), causal_test_case : CausalTestCase , estimate_type : str = "ate"
108
- ) -> CausalTestResult :
105
+ def execute_test (self , estimator : type (Estimator ), causal_test_case : CausalTestCase ) -> CausalTestResult :
109
106
"""Execute a causal test case and return the causal test result.
110
107
111
108
Test case execution proceeds with the following steps:
@@ -120,7 +117,6 @@ def execute_test(
120
117
121
118
:param estimator: A reference to an Estimator class.
122
119
:param causal_test_case: The CausalTestCase object to be tested
123
- :param estimate_type: A string which denotes the type of estimate to return, ATE or CATE.
124
120
:return causal_test_result: A CausalTestResult for the executed causal test case.
125
121
"""
126
122
if self .scenario_execution_data_df .empty :
@@ -142,18 +138,17 @@ def execute_test(
142
138
if self ._check_positivity_violation (variables_for_positivity ):
143
139
raise ValueError ("POSITIVITY VIOLATION -- Cannot proceed." )
144
140
145
- causal_test_result = self ._return_causal_test_results (estimate_type , estimator , causal_test_case )
141
+ causal_test_result = self ._return_causal_test_results (estimator , causal_test_case )
146
142
return causal_test_result
147
143
148
- def _return_causal_test_results (self , estimate_type , estimator , causal_test_case ):
144
+ def _return_causal_test_results (self , estimator , causal_test_case ):
149
145
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
150
146
151
- :param estimate_type: A string which denotes the type of estimate to return
152
147
:param estimator: An Estimator class object
153
148
:param causal_test_case: The concrete test case to be executed
154
149
:return: a CausalTestResult object containing the confidence intervals
155
150
"""
156
- if estimate_type == "cate" :
151
+ if causal_test_case . estimate_type == "cate" :
157
152
logger .debug ("calculating cate" )
158
153
if not hasattr (estimator , "estimate_cates" ):
159
154
raise NotImplementedError (f"{ estimator .__class__ } has no CATE method." )
@@ -165,7 +160,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
165
160
effect_modifier_configuration = causal_test_case .effect_modifier_configuration ,
166
161
confidence_intervals = confidence_intervals ,
167
162
)
168
- elif estimate_type == "risk_ratio" :
163
+ elif causal_test_case . estimate_type == "risk_ratio" :
169
164
logger .debug ("calculating risk_ratio" )
170
165
risk_ratio , confidence_intervals = estimator .estimate_risk_ratio ()
171
166
causal_test_result = CausalTestResult (
@@ -174,7 +169,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
174
169
effect_modifier_configuration = causal_test_case .effect_modifier_configuration ,
175
170
confidence_intervals = confidence_intervals ,
176
171
)
177
- elif estimate_type == "coefficient" :
172
+ elif causal_test_case . estimate_type == "coefficient" :
178
173
logger .debug ("calculating coefficient" )
179
174
coefficient , confidence_intervals = estimator .estimate_unit_ate ()
180
175
causal_test_result = CausalTestResult (
@@ -183,7 +178,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
183
178
effect_modifier_configuration = causal_test_case .effect_modifier_configuration ,
184
179
confidence_intervals = confidence_intervals ,
185
180
)
186
- elif estimate_type == "ate" :
181
+ elif causal_test_case . estimate_type == "ate" :
187
182
logger .debug ("calculating ate" )
188
183
ate , confidence_intervals = estimator .estimate_ate ()
189
184
causal_test_result = CausalTestResult (
@@ -194,7 +189,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
194
189
)
195
190
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
196
191
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
197
- elif estimate_type == "ate_calculated" :
192
+ elif causal_test_case . estimate_type == "ate_calculated" :
198
193
logger .debug ("calculating ate" )
199
194
ate , confidence_intervals = estimator .estimate_ate_calculated ()
200
195
causal_test_result = CausalTestResult (
@@ -206,7 +201,9 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
206
201
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
207
202
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
208
203
else :
209
- raise ValueError (f"Invalid estimate type { estimate_type } , expected 'ate', 'cate', or 'risk_ratio'" )
204
+ raise ValueError (
205
+ f"Invalid estimate type { causal_test_case .estimate_type } , expected 'ate', 'cate', or 'risk_ratio'"
206
+ )
210
207
return causal_test_result
211
208
212
209
def _check_positivity_violation (self , variables_list ):
0 commit comments