Skip to content

Commit e23b9b1

Browse files
Merge pull request #162 from CITCOM-project/json_read_multiple_files
Json read multiple files
2 parents 505cca6 + 2ca274d commit e23b9b1

File tree

8 files changed

+91
-66
lines changed

8 files changed

+91
-66
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,22 +127,23 @@ def run_system_with_input_configuration(self, input_configuration: dict) -> pd.D
127127

128128

129129
class ObservationalDataCollector(DataCollector):
130-
"""A data collector that extracts data that is relevant to the specified scenario from a csv of execution data."""
130+
"""A data collector that extracts data that is relevant to the specified scenario from a dataframe of execution
131+
data."""
131132

132-
def __init__(self, scenario: Scenario, csv_path: str):
133+
def __init__(self, scenario: Scenario, data: pd.DataFrame):
133134
super().__init__(scenario)
134-
self.csv_path = csv_path
135+
self.data = data
135136

136137
def collect_data(self, **kwargs) -> pd.DataFrame:
137-
"""Read a csv containing execution data for the system-under-test into a pandas dataframe and filter to remove
138+
"""Read a pandas dataframe and filter to remove
138139
any data which is invalid for the scenario-under-test.
139140
140141
Data is invalid if it does not meet the constraints outlined in the scenario-under-test (Scenario).
141142
142143
:return: A pandas dataframe containing execution data that is valid for the scenario-under-test.
143144
"""
144145

145-
execution_data_df = pd.read_csv(self.csv_path, **kwargs)
146+
execution_data_df = self.data
146147
for meta in self.scenario.metas():
147148
meta.populate(execution_data_df)
148149
scenario_execution_data_df = self.filter_valid_data(execution_data_df)

causal_testing/json_front/json_class.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,20 @@ class JsonUtility(ABC):
4646
def __init__(self, log_path):
4747
self.paths = None
4848
self.variables = None
49-
self.data = None
49+
self.data = []
5050
self.test_plan = None
5151
self.modelling_scenario = None
5252
self.causal_specification = None
5353
self.setup_logger(log_path)
5454

55-
def set_paths(self, json_path: str, dag_path: str, data_path: str):
55+
def set_paths(self, json_path: str, dag_path: str, data_paths: str):
5656
"""
5757
Takes a path of the directory containing all scenario specific files and creates individual paths for each file
5858
:param json_path: string path representation to .json file containing test specifications
5959
:param dag_path: string path representation to the .dot file containing the Causal DAG
6060
:param data_path: string path representation to the data file
6161
"""
62-
self.paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_path=data_path)
62+
self.paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
6363

6464
def set_variables(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
6565
"""Populate the Causal Variables
@@ -132,14 +132,15 @@ def _json_parse(self):
132132
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
133133
with open(self.paths.json_path, encoding="utf-8") as f:
134134
self.test_plan = json.load(f)
135-
136-
self.data = pd.read_csv(self.paths.data_path)
135+
for data_file in self.paths.data_paths:
136+
df = pd.read_csv(data_file, header=0)
137+
self.data.append(df)
138+
self.data = pd.concat(self.data)
137139

138140
def _populate_metas(self):
139141
"""
140142
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
141143
"""
142-
143144
for meta in self.variables.metas:
144145
meta.populate(self.data)
145146

@@ -193,8 +194,10 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
193194
- causal_test_engine - Test Engine instance for the test being run
194195
- estimation_model - Estimator instance for the test being run
195196
"""
196-
data_collector = ObservationalDataCollector(self.modelling_scenario, self.paths.data_path)
197+
198+
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data)
197199
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
200+
198201
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
199202
treatment_var = causal_test_case.treatment_variable
200203
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
@@ -252,6 +255,7 @@ def get_args(test_args=None) -> argparse.Namespace:
252255
"--data_path",
253256
help="Specify path to file containing runtime data",
254257
required=True,
258+
nargs="+",
255259
)
256260
parser.add_argument(
257261
"--dag_path",
@@ -277,12 +281,12 @@ class JsonClassPaths:
277281

278282
json_path: Path
279283
dag_path: Path
280-
data_path: Path
284+
data_paths: list[Path]
281285

282-
def __init__(self, json_path: str, dag_path: str, data_path: str):
286+
def __init__(self, json_path: str, dag_path: str, data_paths: str):
283287
self.json_path = Path(json_path)
284288
self.dag_path = Path(dag_path)
285-
self.data_path = Path(data_path)
289+
self.data_paths = [Path(path) for path in data_paths]
286290

287291

288292
@dataclass()

causal_testing/testing/causal_test_engine.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,11 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
7171
minimal_adjustment_set = minimal_adjustment_set - set(edge.treatment_variable.name)
7272
minimal_adjustment_set = minimal_adjustment_set - set(edge.outcome_variable.name)
7373

74-
variables_for_positivity = (
75-
list(minimal_adjustment_set) + [edge.treatment_variable.name] + [edge.outcome_variable.name]
76-
)
74+
variables_for_positivity = list(minimal_adjustment_set) + [
75+
edge.treatment_variable.name,
76+
edge.outcome_variable.name,
77+
]
78+
7779
if self._check_positivity_violation(variables_for_positivity):
7880
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")
7981

@@ -209,13 +211,15 @@ def _check_positivity_violation(self, variables_list):
209211
:param variables_list: The list of variables for which positivity must be satisfied.
210212
:return: True if positivity is violated, False otherwise.
211213
"""
212-
if not set(variables_list).issubset(self.scenario_execution_data_df.columns):
214+
if not (set(variables_list) - {x.name for x in self.scenario.hidden_variables()}).issubset(
215+
self.scenario_execution_data_df.columns
216+
):
213217
missing_variables = set(variables_list) - set(self.scenario_execution_data_df.columns)
214218
logger.warning(
215-
"Positivity violation: missing data for variables {missing_variables}.\n"
219+
"Positivity violation: missing data for variables %s.\n"
216220
"Causal inference is only valid if a well-specified parametric model is used.\n"
217221
"Alternatively, consider restricting analysis to executions without the variables:"
218-
" %s.",
222+
".",
219223
missing_variables,
220224
)
221225
return True

tests/data_collection_tests/test_observational_data_collector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,18 @@ class Color(Enum):
3838

3939
def test_not_all_variables_in_data(self):
4040
scenario = Scenario({self.X1, self.X2, self.X3, self.X4})
41-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
41+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
4242
self.assertRaises(IndexError, observational_data_collector.collect_data)
4343

4444
def test_all_variables_in_data(self):
4545
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2})
46-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
46+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
4747
df = observational_data_collector.collect_data(index_col=0)
4848
assert df.equals(self.observational_df), f"\n{df}\nwas not equal to\n{self.observational_df}"
4949

5050
def test_data_constraints(self):
5151
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2}, {self.X1.z3 > 2})
52-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
52+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
5353
df = observational_data_collector.collect_data(index_col=0)
5454
expected = self.observational_df.loc[[2, 3]]
5555
assert df.equals(expected), f"\n{df}\nwas not equal to\n{expected}"
@@ -60,7 +60,7 @@ def populate_m(data):
6060

6161
meta = Meta("M", int, populate_m)
6262
scenario = Scenario({self.X1, meta})
63-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
63+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
6464
data = observational_data_collector.collect_data()
6565
assert all((m == 2 * x1 for x1, m in zip(data["X1"], data["M"])))
6666

tests/json_front_tests/test_json_class.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ def setUp(self) -> None:
2626
dag_file_name = "dag.dot"
2727
data_file_name = "data.csv"
2828
test_data_dir_path = Path("tests/resources/data")
29-
self.json_path = test_data_dir_path / json_file_name
30-
self.dag_path = test_data_dir_path / dag_file_name
31-
self.data_path = test_data_dir_path / data_file_name
29+
self.json_path = str(test_data_dir_path / json_file_name)
30+
self.dag_path = str(test_data_dir_path / dag_file_name)
31+
self.data_path = [str(test_data_dir_path / data_file_name)]
3232
self.json_class = JsonUtility("logs.log")
3333
self.example_distribution = scipy.stats.uniform(1, 10)
3434
self.input_dict_list = [{"name": "test_input", "datatype": float, "distribution": self.example_distribution}]
@@ -40,7 +40,7 @@ def setUp(self) -> None:
4040
def test_setting_paths(self):
4141
self.assertEqual(self.json_class.paths.json_path, Path(self.json_path))
4242
self.assertEqual(self.json_class.paths.dag_path, Path(self.dag_path))
43-
self.assertEqual(self.json_class.paths.data_path, Path(self.data_path))
43+
self.assertEqual(self.json_class.paths.data_paths, [Path(self.data_path[0])]) # Needs to be list of Paths
4444

4545
def test_set_inputs(self):
4646
ctf_input = [Input("test_input", float, self.example_distribution)]
@@ -61,7 +61,7 @@ def test_set_metas(self):
6161

6262
def test_argparse(self):
6363
args = self.json_class.get_args(["--data_path=data.csv", "--dag_path=dag.dot", "--json_path=tests.json"])
64-
self.assertEqual(args.data_path, "data.csv")
64+
self.assertEqual(args.data_path, ["data.csv"])
6565
self.assertEqual(args.dag_path, "dag.dot")
6666
self.assertEqual(args.json_path, "tests.json")
6767

tests/testing_tests/test_causal_test_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def setUp(self) -> None:
5959

6060
# 5. Create observational data collector
6161
# Obsolete?
62-
self.data_collector = ObservationalDataCollector(self.scenario, self.observational_data_csv_path)
62+
self.data_collector = ObservationalDataCollector(self.scenario, df)
6363

6464
# 5. Create causal test engine
6565
self.causal_test_engine = CausalTestEngine(self.causal_specification, self.data_collector)

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,17 @@ def test_None_ci(self):
2727

2828
self.assertIsNone(ctr.ci_low())
2929
self.assertIsNone(ctr.ci_high())
30-
self.assertEqual(ctr.to_dict(),
31-
{"treatment": "A",
32-
"control_value": 0,
33-
"treatment_value": 1,
34-
"outcome": "A",
35-
"adjustment_set": set(),
36-
"test_value": test_value})
30+
self.assertEqual(
31+
ctr.to_dict(),
32+
{
33+
"treatment": "A",
34+
"control_value": 0,
35+
"treatment_value": 1,
36+
"outcome": "A",
37+
"adjustment_set": set(),
38+
"test_value": test_value,
39+
},
40+
)
3741

3842
def test_empty_adjustment_set(self):
3943
test_value = TestValue(type="ate", value=0)
@@ -46,13 +50,18 @@ def test_empty_adjustment_set(self):
4650

4751
self.assertIsNone(ctr.ci_low())
4852
self.assertIsNone(ctr.ci_high())
49-
self.assertEqual(str(ctr), ("Causal Test Result\n==============\n"
50-
"Treatment: A\n"
51-
"Control value: 0\n"
52-
"Treatment value: 1\n"
53-
"Outcome: A\n"
54-
"Adjustment set: set()\n"
55-
"ate: 0\n" ))
53+
self.assertEqual(
54+
str(ctr),
55+
(
56+
"Causal Test Result\n==============\n"
57+
"Treatment: A\n"
58+
"Control value: 0\n"
59+
"Treatment value: 1\n"
60+
"Outcome: A\n"
61+
"Adjustment set: set()\n"
62+
"ate: 0\n"
63+
),
64+
)
5665

5766
def test_exactValue_pass(self):
5867
test_value = TestValue(type="ate", value=5.05)
@@ -97,20 +106,29 @@ def test_someEffect_fail(self):
97106
)
98107
ev = SomeEffect()
99108
self.assertFalse(ev.apply(ctr))
100-
self.assertEqual(str(ctr), ("Causal Test Result\n==============\n"
101-
"Treatment: A\n"
102-
"Control value: 0\n"
103-
"Treatment value: 1\n"
104-
"Outcome: A\n"
105-
"Adjustment set: set()\n"
106-
"ate: 0\n"
107-
"Confidence intervals: [-0.1, 0.2]\n" ))
108-
self.assertEqual(ctr.to_dict(),
109-
{"treatment": "A",
110-
"control_value": 0,
111-
"treatment_value": 1,
112-
"outcome": "A",
113-
"adjustment_set": set(),
114-
"test_value": test_value,
115-
"ci_low": -0.1,
116-
"ci_high": 0.2})
109+
self.assertEqual(
110+
str(ctr),
111+
(
112+
"Causal Test Result\n==============\n"
113+
"Treatment: A\n"
114+
"Control value: 0\n"
115+
"Treatment value: 1\n"
116+
"Outcome: A\n"
117+
"Adjustment set: set()\n"
118+
"ate: 0\n"
119+
"Confidence intervals: [-0.1, 0.2]\n"
120+
),
121+
)
122+
self.assertEqual(
123+
ctr.to_dict(),
124+
{
125+
"treatment": "A",
126+
"control_value": 0,
127+
"treatment_value": 1,
128+
"outcome": "A",
129+
"adjustment_set": set(),
130+
"test_value": test_value,
131+
"ci_low": -0.1,
132+
"ci_high": 0.2,
133+
},
134+
)

tests/testing_tests/test_causal_test_suite.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ def setUp(self) -> None:
4040
df = pd.DataFrame({"D": list(np.random.normal(60, 10, 1000))}) # D = exogenous
4141
df["A"] = [1 if d > 50 else 0 for d in df["D"]]
4242
df["C"] = df["D"] + (4 * (df["A"] + 2)) # C = (4*(A+2)) + D
43-
self.observational_data_csv_path = os.path.join(temp_dir_path, "observational_data.csv")
44-
df.to_csv(self.observational_data_csv_path, index=False)
45-
43+
self.df = df
4644
self.causal_dag = CausalDAG(dag_dot_path)
4745

4846
# 3. Specify data structures required for test suite
@@ -126,6 +124,6 @@ def create_causal_test_engine(self):
126124
"""
127125
causal_specification = CausalSpecification(self.scenario, self.causal_dag)
128126

129-
data_collector = ObservationalDataCollector(self.scenario, self.observational_data_csv_path)
127+
data_collector = ObservationalDataCollector(self.scenario, self.df)
130128
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
131129
return causal_test_engine

0 commit comments

Comments
 (0)