|
1 | 1 | """This module contains the JsonUtility class, details of using this class can be found here:
|
2 | 2 | https://causal-testing-framework.readthedocs.io/en/latest/json_front_end.html"""
|
| 3 | + |
3 | 4 | import argparse
|
4 | 5 | import json
|
5 | 6 | import logging
|
6 | 7 |
|
7 | 8 | from abc import ABC
|
| 9 | +from dataclasses import dataclass |
8 | 10 | from pathlib import Path
|
9 | 11 |
|
10 | 12 | import pandas as pd
|
@@ -42,49 +44,45 @@ class JsonUtility(ABC):
|
42 | 44 | """
|
43 | 45 |
|
def __init__(self, log_path):
    """
    Create a utility with no scenario loaded yet and configure logging.

    :param log_path: value handed straight to ``setup_logger``
    """
    # Supplied later through set_paths and set_variables.
    self.paths = self.variables = None
    # Filled in when setup runs.
    self.data = self.test_plan = None
    self.modelling_scenario = self.causal_specification = None
    self.setup_logger(log_path)
|
56 | 54 |
|
57 |
| - def set_path(self, json_path: str, dag_path: str, data_path: str): |
def set_paths(self, json_path: str, dag_path: str, data_path: str):
    """
    Store the locations of the three scenario files on ``self.paths`` as a JsonClassPaths instance.

    :param json_path: string path representation to .json file containing test specifications
    :param dag_path: string path representation to the .dot file containing the Causal DAG
    :param data_path: string path representation to the data file
    """
    scenario_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_path=data_path)
    self.paths = scenario_paths
71 | 67 |
|
72 |
| - def set_variables(self, inputs: dict, outputs: dict, metas: dict): |
def set_variables(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
    """
    Populate the Causal Variables by wrapping the three specifications in a CausalVariables instance.

    :param inputs: list of dictionaries describing the input variables
    :param outputs: list of dictionaries describing the output variables
    :param metas: list of dictionaries describing the meta variables
    """
    causal_variables = CausalVariables(inputs=inputs, outputs=outputs, metas=metas)
    self.variables = causal_variables
82 | 80 | def setup(self):
|
83 | 81 | """Function to populate all the necessary parts of the json_class needed to execute tests"""
|
84 |
| - self.modelling_scenario = Scenario(self.inputs + self.outputs + self.metas, None) |
| 82 | + self.modelling_scenario = Scenario(self.variables.inputs + self.variables.outputs + self.variables.metas, None) |
85 | 83 | self.modelling_scenario.setup_treatment_variables()
|
86 | 84 | self.causal_specification = CausalSpecification(
|
87 |
| - scenario=self.modelling_scenario, causal_dag=CausalDAG(self.dag_path) |
| 85 | + scenario=self.modelling_scenario, causal_dag=CausalDAG(self.paths.dag_path) |
88 | 86 | )
|
89 | 87 | self._json_parse()
|
90 | 88 | self._populate_metas()
|
@@ -139,20 +137,20 @@ def _execute_tests(self, concrete_tests, estimators, test, f_flag):
|
139 | 137 |
|
140 | 138 | def _json_parse(self):
|
141 | 139 | """Parse a JSON input file into inputs, outputs, metas and a test plan"""
|
142 |
| - with open(self.json_path, encoding="utf-8") as f: |
| 140 | + with open(self.paths.json_path, encoding="utf-8") as f: |
143 | 141 | self.test_plan = json.load(f)
|
144 | 142 |
|
145 |
| - self.data = pd.read_csv(self.data_path) |
| 143 | + self.data = pd.read_csv(self.paths.data_path) |
146 | 144 |
|
147 | 145 | def _populate_metas(self):
|
148 | 146 | """
|
149 | 147 | Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
|
150 | 148 | """
|
151 | 149 |
|
152 |
| - for meta in self.metas: |
| 150 | + for meta in self.variables.metas: |
153 | 151 | meta.populate(self.data)
|
154 | 152 |
|
155 |
| - for var in self.metas + self.outputs: |
| 153 | + for var in self.variables.metas + self.variables.outputs: |
156 | 154 | if not var.distribution:
|
157 | 155 | fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
|
158 | 156 | fitter.fit()
|
@@ -202,7 +200,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
|
202 | 200 | - causal_test_engine - Test Engine instance for the test being run
|
203 | 201 | - estimation_model - Estimator instance for the test being run
|
204 | 202 | """
|
205 |
| - data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path) |
| 203 | + data_collector = ObservationalDataCollector(self.modelling_scenario, self.paths.data_path) |
206 | 204 | causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
|
207 | 205 | minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
|
208 | 206 | treatment_var = causal_test_case.treatment_variable
|
@@ -273,3 +271,48 @@ def get_args(test_args=None) -> argparse.Namespace:
|
273 | 271 | required=True,
|
274 | 272 | )
|
275 | 273 | return parser.parse_args(test_args)
|
| 274 | + |
| 275 | + |
@dataclass
class JsonClassPaths:
    """
    A dataclass that holds the file paths used by the JsonUtility class as Path objects.

    Each value may be given as a string (or anything ``Path`` accepts); it is converted to a
    ``Path`` immediately after the instance is initialised.

    :param json_path: string path representation to .json file containing test specifications
    :param dag_path: string path representation to the .dot file containing the Causal DAG
    :param data_path: string path representation to the data file
    """

    json_path: Path
    dag_path: Path
    data_path: Path

    def __post_init__(self):
        # Convert here instead of hand-writing __init__, so the dataclass-generated
        # constructor stays in use and the fields always end up as Path objects.
        self.json_path = Path(self.json_path)
        self.dag_path = Path(self.dag_path)
        self.data_path = Path(self.data_path)
| 298 | + |
| 299 | + |
@dataclass()
class CausalVariables:
    """
    A dataclass that converts lists of variable-specification dictionaries into Input, Output and
    Meta objects.

    :param inputs: list of dictionaries, each with "name", "type" and "distribution" keys
    :param outputs: list of dictionaries, each with "name" and "type" keys and an optional
        "distribution" key
    :param metas: list of dictionaries, each with "name", "type" and "populate" keys; a falsy value
        (e.g. None or an empty list) results in no Meta variables
    """

    inputs: list[Input]
    outputs: list[Output]
    metas: list[Meta]

    def __init__(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
        self.inputs = [Input(spec["name"], spec["type"], spec["distribution"]) for spec in inputs]
        # Outputs may omit a distribution, in which case None is passed on.
        self.outputs = [Output(spec["name"], spec["type"], spec.get("distribution")) for spec in outputs]
        if metas:
            self.metas = [Meta(spec["name"], spec["type"], spec["populate"]) for spec in metas]
        else:
            self.metas = []
| 318 | + self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else [] |
0 commit comments