Skip to content

Commit 555e830

Browse files
committed
Merge branch 'json_read_multiple_files' of github.com:CITCOM-project/CausalTestingFramework into json_read_multiple_files
2 parents 203ca69 + 46f683f commit 555e830

File tree

6 files changed

+17
-18
lines changed

6 files changed

+17
-18
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,15 @@ def run_system_with_input_configuration(self, input_configuration: dict) -> pd.D
122122

123123

124124
class ObservationalDataCollector(DataCollector):
125-
"""A data collector that extracts data that is relevant to the specified scenario from a csv of execution data."""
125+
"""A data collector that extracts data that is relevant to the specified scenario from a dataframe of execution
126+
data."""
126127

127128
def __init__(self, scenario: Scenario, data: pd.DataFrame):
128129
super().__init__(scenario)
129130
self.data = data
130131

131132
def collect_data(self, **kwargs) -> pd.DataFrame:
132-
"""Read a csv containing execution data for the system-under-test into a pandas dataframe and filter to remove
133+
"""Read a pandas dataframe and filter to remove
133134
any data which is invalid for the scenario-under-test.
134135
135136
Data is invalid if it does not meet the constraints outlined in the scenario-under-test (Scenario).

causal_testing/json_front/json_class.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import argparse
55
import json
66
import logging
7-
import tempfile
87

98
from abc import ABC
109
from dataclasses import dataclass
@@ -47,7 +46,7 @@ class JsonUtility(ABC):
4746
def __init__(self, log_path):
4847
self.paths = None
4948
self.variables = None
50-
self.data = list()
49+
self.data = []
5150
self.test_plan = None
5251
self.modelling_scenario = None
5352
self.causal_specification = None
@@ -137,6 +136,7 @@ def _json_parse(self):
137136
df = pd.read_csv(data_file, header=0)
138137
self.data.append(df)
139138
self.data = pd.concat(self.data)
139+
140140
def _populate_metas(self):
141141
"""
142142
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
@@ -255,7 +255,7 @@ def get_args(test_args=None) -> argparse.Namespace:
255255
"--data_path",
256256
help="Specify path to file containing runtime data",
257257
required=True,
258-
nargs='+',
258+
nargs="+",
259259
)
260260
parser.add_argument(
261261
"--dag_path",
@@ -286,7 +286,7 @@ class JsonClassPaths:
286286
def __init__(self, json_path: str, dag_path: str, data_paths: str):
287287
self.json_path = Path(json_path)
288288
self.dag_path = Path(dag_path)
289-
self.data_paths = [Path(path) for path in [data_paths]]
289+
self.data_paths = [Path(path) for path in data_paths]
290290

291291

292292
@dataclass()

tests/data_collection_tests/test_observational_data_collector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,18 @@ class Color(Enum):
3838

3939
def test_not_all_variables_in_data(self):
4040
scenario = Scenario({self.X1, self.X2, self.X3, self.X4})
41-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
41+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
4242
self.assertRaises(IndexError, observational_data_collector.collect_data)
4343

4444
def test_all_variables_in_data(self):
4545
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2})
46-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
46+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
4747
df = observational_data_collector.collect_data(index_col=0)
4848
assert df.equals(self.observational_df), f"\n{df}\nwas not equal to\n{self.observational_df}"
4949

5050
def test_data_constraints(self):
5151
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2}, {self.X1.z3 > 2})
52-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
52+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
5353
df = observational_data_collector.collect_data(index_col=0)
5454
expected = self.observational_df.loc[[2, 3]]
5555
assert df.equals(expected), f"\n{df}\nwas not equal to\n{expected}"
@@ -60,7 +60,7 @@ def populate_m(data):
6060

6161
meta = Meta("M", int, populate_m)
6262
scenario = Scenario({self.X1, meta})
63-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
63+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
6464
data = observational_data_collector.collect_data()
6565
assert all((m == 2 * x1 for x1, m in zip(data["X1"], data["M"])))
6666

tests/json_front_tests/test_json_class.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def setUp(self) -> None:
2828
test_data_dir_path = Path("tests/resources/data")
2929
self.json_path = str(test_data_dir_path / json_file_name)
3030
self.dag_path = str(test_data_dir_path / dag_file_name)
31-
self.data_path = str(test_data_dir_path / data_file_name)
31+
self.data_path = [str(test_data_dir_path / data_file_name)]
3232
self.json_class = JsonUtility("logs.log")
3333
self.example_distribution = scipy.stats.uniform(1, 10)
3434
self.input_dict_list = [{"name": "test_input", "type": float, "distribution": self.example_distribution}]
@@ -40,7 +40,7 @@ def setUp(self) -> None:
4040
def test_setting_paths(self):
4141
self.assertEqual(self.json_class.paths.json_path, Path(self.json_path))
4242
self.assertEqual(self.json_class.paths.dag_path, Path(self.dag_path))
43-
self.assertEqual(self.json_class.paths.data_path, Path(self.data_path))
43+
self.assertEqual(self.json_class.paths.data_paths, [Path(self.data_path[0])]) # Needs to be list of Paths
4444

4545
def test_set_inputs(self):
4646
ctf_input = [Input("test_input", float, self.example_distribution)]
@@ -61,7 +61,7 @@ def test_set_metas(self):
6161

6262
def test_argparse(self):
6363
args = self.json_class.get_args(["--data_path=data.csv", "--dag_path=dag.dot", "--json_path=tests.json"])
64-
self.assertEqual(args.data_path, "data.csv")
64+
self.assertEqual(args.data_path, ["data.csv"])
6565
self.assertEqual(args.dag_path, "dag.dot")
6666
self.assertEqual(args.json_path, "tests.json")
6767

tests/testing_tests/test_causal_test_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def setUp(self) -> None:
5959

6060
# 5. Create observational data collector
6161
# Obsolete?
62-
self.data_collector = ObservationalDataCollector(self.scenario, self.observational_data_csv_path)
62+
self.data_collector = ObservationalDataCollector(self.scenario, df)
6363

6464
# 5. Create causal test engine
6565
self.causal_test_engine = CausalTestEngine(self.causal_specification, self.data_collector)

tests/testing_tests/test_causal_test_suite.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ def setUp(self) -> None:
4040
df = pd.DataFrame({"D": list(np.random.normal(60, 10, 1000))}) # D = exogenous
4141
df["A"] = [1 if d > 50 else 0 for d in df["D"]]
4242
df["C"] = df["D"] + (4 * (df["A"] + 2)) # C = (4*(A+2)) + D
43-
self.observational_data_csv_path = os.path.join(temp_dir_path, "observational_data.csv")
44-
df.to_csv(self.observational_data_csv_path, index=False)
45-
43+
self.df = df
4644
self.causal_dag = CausalDAG(dag_dot_path)
4745

4846
# 3. Specify data structures required for test suite
@@ -126,6 +124,6 @@ def create_causal_test_engine(self):
126124
"""
127125
causal_specification = CausalSpecification(self.scenario, self.causal_dag)
128126

129-
data_collector = ObservationalDataCollector(self.scenario, self.observational_data_csv_path)
127+
data_collector = ObservationalDataCollector(self.scenario, self.df)
130128
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
131129
return causal_test_engine

0 commit comments

Comments
 (0)