Skip to content

Commit 28aa04c

Browse files
Merge pull request #181 from CITCOM-project/json_fitter
Fixed fitter bug
2 parents 8fafe9d + 48b60d9 commit 28aa04c

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

causal_testing/json_front/json_class.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def setup(self, scenario: Scenario):
7878
def _create_abstract_test_case(self, test, mutates, effects):
7979
assert len(test["mutations"]) == 1
8080
treatment_var = next(self.scenario.variables[v] for v in test["mutations"])
81+
8182
if not treatment_var.distribution:
8283
fitter = Fitter(self.data[treatment_var.name], distributions=get_common_distributions())
8384
fitter.fit()
@@ -223,15 +224,9 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
223224
"""
224225
failed = False
225226

226-
for var in self.scenario.variables_of_type(Meta).union(self.scenario.variables_of_type(Output)):
227-
if not var.distribution:
228-
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
229-
fitter.fit()
230-
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
231-
var.distribution = getattr(scipy.stats, dist)(**params)
232-
self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
233-
234-
causal_test_engine, estimation_model = self._setup_test(causal_test_case, test)
227+
causal_test_engine, estimation_model = self._setup_test(
228+
causal_test_case, test, test["conditions"] if "conditions" in test else None
229+
)
235230
causal_test_result = causal_test_engine.execute_test(
236231
estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
237232
)

tests/json_front_tests/test_json_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def test_run_concrete_json_testcase(self):
232232
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
233233

234234
self.json_class.run_json_tests(effects=effects, estimators=estimators, f_flag=False)
235-
with open("temp_out.txt", 'r') as reader:
235+
with open("temp_out.txt", "r") as reader:
236236
temp_out = reader.readlines()
237237
self.assertIn("failed", temp_out[-1])
238238

0 commit comments

Comments
 (0)