Skip to content

Commit 9bc7e25

Browse files
committed
Fixed fitter bug
1 parent 8fafe9d commit 9bc7e25

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

causal_testing/json_front/json_class.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ def _create_abstract_test_case(self, test, mutates, effects):
8585
treatment_var.distribution = getattr(scipy.stats, dist)(**params)
8686
self._append_to_file(treatment_var.name + f" {dist}({params})", logging.INFO)
8787

88+
if not treatment_var.distribution:
89+
fitter = Fitter(self.data[treatment_var.name], distributions=get_common_distributions())
90+
fitter.fit()
91+
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
92+
treatment_var.distribution = getattr(scipy.stats, dist)(**params)
93+
self._append_to_file(treatment_var.name + f" {dist}({params})", logging.INFO)
94+
8895
abstract_test = AbstractCausalTestCase(
8996
scenario=self.scenario,
9097
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
@@ -223,15 +230,9 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
223230
"""
224231
failed = False
225232

226-
for var in self.scenario.variables_of_type(Meta).union(self.scenario.variables_of_type(Output)):
227-
if not var.distribution:
228-
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
229-
fitter.fit()
230-
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
231-
var.distribution = getattr(scipy.stats, dist)(**params)
232-
self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
233-
234-
causal_test_engine, estimation_model = self._setup_test(causal_test_case, test)
233+
causal_test_engine, estimation_model = self._setup_test(
234+
causal_test_case, test, test["conditions"] if "conditions" in test else None
235+
)
235236
causal_test_result = causal_test_engine.execute_test(
236237
estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
237238
)

tests/json_front_tests/test_json_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def test_run_concrete_json_testcase(self):
232232
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
233233

234234
self.json_class.run_json_tests(effects=effects, estimators=estimators, f_flag=False)
235-
with open("temp_out.txt", 'r') as reader:
235+
with open("temp_out.txt", "r") as reader:
236236
temp_out = reader.readlines()
237237
self.assertIn("failed", temp_out[-1])
238238

0 commit comments

Comments
 (0)