Skip to content

Commit 90429ce

Browse files
Remove set_variables and pass in Scenario
1 parent d8ee91c commit 90429ce

File tree

1 file changed

+17
-28
lines changed

1 file changed

+17
-28
lines changed

causal_testing/json_front/json_class.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
logger = logging.getLogger(__name__)
2727

2828

29-
class JsonUtility(ABC):
29+
class JsonUtility:
3030
"""
3131
The JsonUtility Class provides the functionality to use structured JSON to setup and run causal tests on the
3232
CausalTestingFramework.
@@ -48,7 +48,7 @@ def __init__(self, log_path):
4848
self.variables = None
4949
self.data = []
5050
self.test_plan = None
51-
self.modelling_scenario = None
51+
self.scenario = None
5252
self.causal_specification = None
5353
self.setup_logger(log_path)
5454

@@ -61,36 +61,27 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: str):
6161
"""
6262
self.paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
6363

64-
def set_variables(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
65-
"""Populate the Causal Variables
66-
:param inputs:
67-
:param outputs:
68-
:param metas:
69-
"""
70-
71-
self.variables = CausalVariables(inputs=inputs, outputs=outputs, metas=metas)
72-
73-
def setup(self):
64+
def setup(self, scenario: Scenario):
7465
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
75-
self.modelling_scenario = Scenario(self.variables.inputs + self.variables.outputs + self.variables.metas, None)
76-
self.modelling_scenario.setup_treatment_variables()
66+
self.scenario = scenario
67+
self.scenario.setup_treatment_variables()
7768
self.causal_specification = CausalSpecification(
78-
scenario=self.modelling_scenario, causal_dag=CausalDAG(self.paths.dag_path)
69+
scenario=self.scenario, causal_dag=CausalDAG(self.paths.dag_path)
7970
)
8071
self._json_parse()
8172
self._populate_metas()
8273

8374
def _create_abstract_test_case(self, test, mutates, effects):
8475
assert len(test["mutations"]) == 1
8576
abstract_test = AbstractCausalTestCase(
86-
scenario=self.modelling_scenario,
77+
scenario=self.scenario,
8778
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
88-
treatment_variable=next(self.modelling_scenario.variables[v] for v in test["mutations"]),
79+
treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
8980
expected_causal_effect={
90-
self.modelling_scenario.variables[variable]: effects[effect]
81+
self.scenario.variables[variable]: effects[effect]
9182
for variable, effect in test["expectedEffect"].items()
9283
},
93-
effect_modifiers={self.modelling_scenario.variables[v] for v in test["effect_modifiers"]}
84+
effect_modifiers={self.scenario.variables[v] for v in test["effect_modifiers"]}
9485
if "effect_modifiers" in test
9586
else {},
9687
estimate_type=test["estimate_type"],
@@ -141,10 +132,9 @@ def _populate_metas(self):
141132
"""
142133
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
143134
"""
144-
for meta in self.variables.metas:
135+
for meta in self.scenario.variables_of_type(Meta):
145136
meta.populate(self.data)
146-
147-
for var in self.variables.metas + self.variables.outputs:
137+
for var in self.scenario.variables_of_type(Meta).union(self.scenario.variables_of_type(Output)):
148138
if not var.distribution:
149139
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
150140
fitter.fit()
@@ -195,7 +185,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
195185
- estimation_model - Estimator instance for the test being run
196186
"""
197187

198-
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data)
188+
data_collector = ObservationalDataCollector(self.scenario, self.data)
199189
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
200190

201191
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
@@ -289,17 +279,16 @@ def __init__(self, json_path: str, dag_path: str, data_paths: str):
289279
self.data_paths = [Path(path) for path in data_paths]
290280

291281

292-
@dataclass()
293282
class CausalVariables:
294283
"""
295284
A dataclass that converts
296285
"""
297286

298-
inputs: list[Input]
299-
outputs: list[Output]
300-
metas: list[Meta]
301-
302287
def __init__(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
303288
self.inputs = [Input(**i) for i in inputs]
304289
self.outputs = [Output(**o) for o in outputs]
305290
self.metas = [Meta(**m) for m in metas] if metas else []
291+
292+
def __iter__(self):
293+
for var in self.inputs + self.outputs + self.metas:
294+
yield var

0 commit comments

Comments
 (0)