@@ -203,28 +203,19 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[
203
203
minimal_adjustment_set = self .causal_specification .causal_dag .identification (causal_test_case .base_test_case )
204
204
treatment_var = causal_test_case .treatment_variable
205
205
minimal_adjustment_set = minimal_adjustment_set - {treatment_var }
206
+ estimator_kwargs = {
207
+ "treatment" : treatment_var .name ,
208
+ "treatment_value" : causal_test_case .treatment_value ,
209
+ "control_value" : causal_test_case .control_value ,
210
+ "adjustment_set" : minimal_adjustment_set ,
211
+ "outcome" : causal_test_case .outcome_variable .name ,
212
+ "df" : causal_test_engine .scenario_execution_data_df ,
213
+ "effect_modifiers" : causal_test_case .effect_modifier_configuration ,
214
+ }
206
215
if "formula" in test :
207
- estimation_model = test ["estimator" ](
208
- treatment = treatment_var .name ,
209
- treatment_value = causal_test_case .treatment_value ,
210
- control_value = causal_test_case .control_value ,
211
- adjustment_set = minimal_adjustment_set ,
212
- outcome = causal_test_case .outcome_variable .name ,
213
- df = causal_test_engine .scenario_execution_data_df ,
214
- effect_modifiers = causal_test_case .effect_modifier_configuration ,
215
- formula = test ["formula" ],
216
- )
217
- else :
218
- estimation_model = test ["estimator" ](
219
- treatment = treatment_var .name ,
220
- treatment_value = causal_test_case .treatment_value ,
221
- control_value = causal_test_case .control_value ,
222
- adjustment_set = minimal_adjustment_set ,
223
- outcome = causal_test_case .outcome_variable .name ,
224
- df = causal_test_engine .scenario_execution_data_df ,
225
- effect_modifiers = causal_test_case .effect_modifier_configuration ,
226
- )
216
+ estimator_kwargs ["formula" ] = test ["formula" ]
227
217
218
+ estimation_model = test ["estimator" ](** estimator_kwargs )
228
219
return causal_test_engine , estimation_model
229
220
230
221
def _append_to_file (self , line : str , log_level : int = None ):
@@ -235,9 +226,7 @@ def _append_to_file(self, line: str, log_level: int = None):
235
226
is possible to use the inbuilt logging level variables such as logging.INFO and logging.WARNING
236
227
"""
237
228
with open (self .output_path , "a" , encoding = "utf-8" ) as f :
238
- f .write (
239
- line
240
- )
229
+ f .write (line )
241
230
if log_level :
242
231
logger .log (level = log_level , msg = line )
243
232
0 commit comments