@@ -89,10 +89,11 @@ def setup(self):
89
89
self ._populate_metas ()
90
90
91
91
def _create_abstract_test_case (self , test , mutates , effects ):
92
+
92
93
abstract_test = AbstractCausalTestCase (
93
94
scenario = self .modelling_scenario ,
94
95
intervention_constraints = [mutates [v ](k ) for k , v in test ["mutations" ].items ()],
95
- treatment_variables = { self .modelling_scenario .variables [v ] for v in test ["mutations" ]} ,
96
+ treatment_variables = self .modelling_scenario .variables [next ( iter ( test ["mutations" ]))] ,
96
97
expected_causal_effect = {
97
98
self .modelling_scenario .variables [variable ]: effects [effect ]
98
99
for variable , effect in test ["expectedEffect" ].items ()
@@ -121,7 +122,7 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
121
122
concrete_tests , dummy = abstract_test .generate_concrete_tests (5 , 0.05 )
122
123
logger .info ("Executing test: %s" , test ["name" ])
123
124
logger .info (abstract_test )
124
- logger .info ([( v . name , v . distribution ) for v in abstract_test .treatment_variables ])
125
+ logger .info ([abstract_test . treatment_variables . name , abstract_test .treatment_variables . distribution ])
125
126
logger .info ("Number of concrete tests for test case: %s" , str (len (concrete_tests )))
126
127
failures = self ._execute_tests (concrete_tests , estimators , test , f_flag )
127
128
logger .info ("%s/%s failed" , failures , len (concrete_tests ))
@@ -204,15 +205,15 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
204
205
"""
205
206
data_collector = ObservationalDataCollector (self .modelling_scenario , self .data_path )
206
207
causal_test_engine = CausalTestEngine (self .causal_specification , data_collector , index_col = 0 )
207
- causal_test_engine . identification (causal_test_case )
208
+ minimal_adjustment_set = self . causal_specification . causal_dag . identification (causal_test_case . base_causal_test )
208
209
treatment_vars = list (causal_test_case .treatment_input_configuration )
209
- minimal_adjustment_set = causal_test_engine . minimal_adjustment_set - {v .name for v in treatment_vars }
210
+ minimal_adjustment_set = minimal_adjustment_set - {v .name for v in treatment_vars }
210
211
estimation_model = estimator (
211
212
(list (treatment_vars )[0 ].name ,),
212
213
[causal_test_case .treatment_input_configuration [v ] for v in treatment_vars ][0 ],
213
214
[causal_test_case .control_input_configuration [v ] for v in treatment_vars ][0 ],
214
215
minimal_adjustment_set ,
215
- (list ( causal_test_case .outcome_variables )[ 0 ] .name ,),
216
+ (causal_test_case .outcome_variable .name ,),
216
217
causal_test_engine .scenario_execution_data_df ,
217
218
effect_modifiers = causal_test_case .effect_modifier_configuration ,
218
219
)
0 commit comments