Skip to content

Commit b227234

Browse files
Group local instance attributes to dataclasses
1 parent 9596056 commit b227234

File tree

2 files changed

+79
-36
lines changed

2 files changed

+79
-36
lines changed

causal_testing/json_front/json_class.py

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""This module contains the JsonUtility class, details of using this class can be found here:
22
https://causal-testing-framework.readthedocs.io/en/latest/json_front_end.html"""
3+
34
import argparse
45
import json
56
import logging
67

78
from abc import ABC
9+
from dataclasses import dataclass
810
from pathlib import Path
911

1012
import pandas as pd
@@ -42,49 +44,45 @@ class JsonUtility(ABC):
4244
"""
4345

4446
def __init__(self, log_path):
45-
self.json_path = None
46-
self.dag_path = None
47-
self.data_path = None
48-
self.inputs = None
49-
self.outputs = None
50-
self.metas = None
47+
self.paths = None
48+
self.variables = None
5149
self.data = None
5250
self.test_plan = None
5351
self.modelling_scenario = None
5452
self.causal_specification = None
5553
self.setup_logger(log_path)
5654

57-
def set_path(self, json_path: str, dag_path: str, data_path: str):
55+
def set_paths(self, json_path: str, dag_path: str, data_path: str):
5856
"""
5957
Takes a path of the directory containing all scenario specific files and creates individual paths for each file
6058
:param json_path: string path representation to .json file containing test specifications
6159
:param dag_path: string path representation to the .dot file containing the Causal DAG
6260
:param data_path: string path representation to the data file
63-
:returns:
64-
- json_path -
65-
- dag_path -
66-
- data_path -
6761
"""
68-
self.json_path = Path(json_path)
69-
self.dag_path = Path(dag_path)
70-
self.data_path = Path(data_path)
62+
self.paths = JsonClassPaths(
63+
json_path=json_path,
64+
dag_path=dag_path,
65+
data_path=data_path
66+
)
7167

72-
def set_variables(self, inputs: dict, outputs: dict, metas: dict):
68+
def set_variables(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
7369
"""Populate the Causal Variables
7470
:param inputs:
7571
:param outputs:
7672
:param metas:
7773
"""
78-
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
79-
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
80-
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []
8174

75+
self.variables = CausalVariables(
76+
inputs=inputs,
77+
outputs=outputs,
78+
metas=metas
79+
)
8280
def setup(self):
8381
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
84-
self.modelling_scenario = Scenario(self.inputs + self.outputs + self.metas, None)
82+
self.modelling_scenario = Scenario(self.variables.inputs + self.variables.outputs + self.variables.metas, None)
8583
self.modelling_scenario.setup_treatment_variables()
8684
self.causal_specification = CausalSpecification(
87-
scenario=self.modelling_scenario, causal_dag=CausalDAG(self.dag_path)
85+
scenario=self.modelling_scenario, causal_dag=CausalDAG(self.paths.dag_path)
8886
)
8987
self._json_parse()
9088
self._populate_metas()
@@ -139,20 +137,20 @@ def _execute_tests(self, concrete_tests, estimators, test, f_flag):
139137

140138
def _json_parse(self):
141139
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
142-
with open(self.json_path, encoding="utf-8") as f:
140+
with open(self.paths.json_path, encoding="utf-8") as f:
143141
self.test_plan = json.load(f)
144142

145-
self.data = pd.read_csv(self.data_path)
143+
self.data = pd.read_csv(self.paths.data_path)
146144

147145
def _populate_metas(self):
148146
"""
149147
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
150148
"""
151149

152-
for meta in self.metas:
150+
for meta in self.variables.metas:
153151
meta.populate(self.data)
154152

155-
for var in self.metas + self.outputs:
153+
for var in self.variables.metas + self.variables.outputs:
156154
if not var.distribution:
157155
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
158156
fitter.fit()
@@ -202,7 +200,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
202200
- causal_test_engine - Test Engine instance for the test being run
203201
- estimation_model - Estimator instance for the test being run
204202
"""
205-
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
203+
data_collector = ObservationalDataCollector(self.modelling_scenario, self.paths.data_path)
206204
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
207205
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
208206
treatment_var = causal_test_case.treatment_variable
@@ -273,3 +271,48 @@ def get_args(test_args=None) -> argparse.Namespace:
273271
required=True,
274272
)
275273
return parser.parse_args(test_args)
274+
275+
276+
@dataclass
277+
class JsonClassPaths:
278+
"""
279+
A dataclass that converts strings of paths to Path objects for use in the JsonUtility class
280+
:param json_path: string path representation to .json file containing test specifications
281+
:param dag_path: string path representation to the .dot file containing the Causal DAG
282+
:param data_path: string path representation to the data file
283+
"""
284+
285+
json_path: Path
286+
dag_path: Path
287+
data_path: Path
288+
289+
def __init__(
290+
self,
291+
json_path: str,
292+
dag_path: str,
293+
data_path: str
294+
):
295+
self.json_path = Path(json_path)
296+
self.dag_path = Path(dag_path)
297+
self.data_path = Path(data_path)
298+
299+
300+
@dataclass()
301+
class CausalVariables:
302+
"""
303+
A dataclass that converts
304+
"""
305+
306+
inputs: list[Input]
307+
outputs: list[Output]
308+
metas: list[Meta]
309+
310+
def __init__(
311+
self,
312+
inputs: list[dict],
313+
outputs: list[dict],
314+
metas: list[dict]
315+
):
316+
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
317+
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
318+
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []

tests/json_front_tests/test_json_class.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,29 @@ def setUp(self) -> None:
3535
self.output_dict_list = [{"name": "test_output", "type": float}]
3636
self.meta_dict_list = [{"name": "test_meta", "type": float, "populate": populate_example}]
3737
self.json_class.set_variables(self.input_dict_list, self.output_dict_list, None)
38-
self.json_class.set_path(self.json_path, self.dag_path, self.data_path)
38+
self.json_class.set_paths(self.json_path, self.dag_path, self.data_path)
3939

4040
def test_setting_paths(self):
41-
self.assertEqual(self.json_class.json_path, Path(self.json_path))
42-
self.assertEqual(self.json_class.dag_path, Path(self.dag_path))
43-
self.assertEqual(self.json_class.data_path, Path(self.data_path))
41+
self.assertEqual(self.json_class.paths.json_path, Path(self.json_path))
42+
self.assertEqual(self.json_class.paths.dag_path, Path(self.dag_path))
43+
self.assertEqual(self.json_class.paths.data_path, Path(self.data_path))
4444

4545
def test_set_inputs(self):
4646
ctf_input = [Input("test_input", float, self.example_distribution)]
47-
self.assertEqual(ctf_input[0].name, self.json_class.inputs[0].name)
48-
self.assertEqual(ctf_input[0].datatype, self.json_class.inputs[0].datatype)
49-
self.assertEqual(ctf_input[0].distribution, self.json_class.inputs[0].distribution)
47+
self.assertEqual(ctf_input[0].name, self.json_class.variables.inputs[0].name)
48+
self.assertEqual(ctf_input[0].datatype, self.json_class.variables.inputs[0].datatype)
49+
self.assertEqual(ctf_input[0].distribution, self.json_class.variables.inputs[0].distribution)
5050

5151
def test_set_outputs(self):
5252
ctf_output = [Output("test_output", float)]
53-
self.assertEqual(ctf_output[0].name, self.json_class.outputs[0].name)
54-
self.assertEqual(ctf_output[0].datatype, self.json_class.outputs[0].datatype)
53+
self.assertEqual(ctf_output[0].name, self.json_class.variables.outputs[0].name)
54+
self.assertEqual(ctf_output[0].datatype, self.json_class.variables.outputs[0].datatype)
5555

5656
def test_set_metas(self):
5757
self.json_class.set_variables(self.input_dict_list, self.output_dict_list, self.meta_dict_list)
5858
ctf_meta = [Meta("test_meta", float, populate_example)]
59-
self.assertEqual(ctf_meta[0].name, self.json_class.metas[0].name)
60-
self.assertEqual(ctf_meta[0].datatype, self.json_class.metas[0].datatype)
59+
self.assertEqual(ctf_meta[0].name, self.json_class.variables.metas[0].name)
60+
self.assertEqual(ctf_meta[0].datatype, self.json_class.variables.metas[0].datatype)
6161

6262
def test_argparse(self):
6363
args = self.json_class.get_args(["--data_path=data.csv", "--dag_path=dag.dot", "--json_path=tests.json"])

0 commit comments

Comments
 (0)