@@ -85,6 +85,13 @@ def _create_abstract_test_case(self, test, mutates, effects):
85
85
treatment_var .distribution = getattr (scipy .stats , dist )(** params )
86
86
self ._append_to_file (treatment_var .name + f" { dist } ({ params } )" , logging .INFO )
87
87
88
+ if not treatment_var .distribution :
89
+ fitter = Fitter (self .data [treatment_var .name ], distributions = get_common_distributions ())
90
+ fitter .fit ()
91
+ (dist , params ) = list (fitter .get_best (method = "sumsquare_error" ).items ())[0 ]
92
+ treatment_var .distribution = getattr (scipy .stats , dist )(** params )
93
+ self ._append_to_file (treatment_var .name + f" { dist } ({ params } )" , logging .INFO )
94
+
88
95
abstract_test = AbstractCausalTestCase (
89
96
scenario = self .scenario ,
90
97
intervention_constraints = [mutates [v ](k ) for k , v in test ["mutations" ].items ()],
@@ -223,15 +230,9 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
223
230
"""
224
231
failed = False
225
232
226
- for var in self .scenario .variables_of_type (Meta ).union (self .scenario .variables_of_type (Output )):
227
- if not var .distribution :
228
- fitter = Fitter (self .data [var .name ], distributions = get_common_distributions ())
229
- fitter .fit ()
230
- (dist , params ) = list (fitter .get_best (method = "sumsquare_error" ).items ())[0 ]
231
- var .distribution = getattr (scipy .stats , dist )(** params )
232
- self ._append_to_file (var .name + f" { dist } ({ params } )" , logging .INFO )
233
-
234
- causal_test_engine , estimation_model = self ._setup_test (causal_test_case , test )
233
+ causal_test_engine , estimation_model = self ._setup_test (
234
+ causal_test_case , test , test ["conditions" ] if "conditions" in test else None
235
+ )
235
236
causal_test_result = causal_test_engine .execute_test (
236
237
estimation_model , causal_test_case , estimate_type = causal_test_case .estimate_type
237
238
)
0 commit comments