Skip to content

Commit a696899

Browse files
Address Github comments + linting
1 parent 1b57e2e commit a696899

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

causal_testing/json_front/json_class.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ class JsonUtility:
4848
def __init__(self, output_path: str, output_overwrite: bool = False):
4949
self.input_paths = None
5050
self.variables = {"inputs": {}, "outputs": {}, "metas": {}}
51-
self.data = []
5251
self.test_plan = None
5352
self.scenario = None
5453
self.causal_specification = None
@@ -69,6 +68,7 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
6968

7069
def setup(self, scenario: Scenario):
7170
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
71+
data = []
7272
self.scenario = scenario
7373
self._get_scenario_variables()
7474
self.scenario.setup_treatment_variables()
@@ -80,22 +80,23 @@ def setup(self, scenario: Scenario):
8080
self.test_plan = json.load(f)
8181
# Populate the data
8282
if self.input_paths.data_paths:
83-
self.data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
84-
if len(self.data) == 0:
83+
data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
84+
if len(data) == 0:
8585
raise ValueError(
8686
"No data found, either provide a path to a file containing data or manually populate the .data "
8787
"attribute with a dataframe before calling .setup()"
8888
)
89-
self._populate_metas()
9089
self.data_collector = ObservationalDataCollector(
91-
self.scenario, self.data)
90+
self.scenario, data)
91+
self._populate_metas()
92+
9293

9394
def _create_abstract_test_case(self, test, mutates, effects):
9495
assert len(test["mutations"]) == 1
9596
treatment_var = next(self.scenario.variables[v] for v in test["mutations"])
9697

9798
if not treatment_var.distribution:
98-
fitter = Fitter(self.data[treatment_var.name], distributions=get_common_distributions())
99+
fitter = Fitter(self.data_collector.data[treatment_var.name], distributions=get_common_distributions())
99100
fitter.fit()
100101
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
101102
treatment_var.distribution = getattr(scipy.stats, dist)(**params)
@@ -247,10 +248,10 @@ def _populate_metas(self):
247248
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
248249
"""
249250
for meta in self.scenario.variables_of_type(Meta):
250-
meta.populate(self.data)
251+
meta.populate(self.data_collector.data)
251252

252253
def _execute_test_case(
253-
self, causal_test_case: CausalTestCase, test: Iterable[Mapping],
254+
self, causal_test_case: CausalTestCase, test: Mapping,
254255
f_flag: bool
255256
) -> (bool, CausalTestResult):
256257
"""Executes a singular test case, prints the results and returns the test case result
@@ -309,7 +310,6 @@ def _setup_test(
309310
"control_value": causal_test_case.control_value,
310311
"adjustment_set": minimal_adjustment_set,
311312
"outcome": causal_test_case.outcome_variable.name,
312-
"df": self.data_collector.collect_data(),
313313
"effect_modifiers": causal_test_case.effect_modifier_configuration,
314314
"alpha": test["alpha"] if "alpha" in test else 0.05,
315315
}

0 commit comments

Comments
 (0)