Skip to content

Commit 8da66cd

Browse files
Update tests
1 parent 79683e0 commit 8da66cd

File tree

4 files changed

+10
-12
lines changed

4 files changed

+10
-12
lines changed

tests/data_collection_tests/test_observational_data_collector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,18 @@ class Color(Enum):
3838

3939
def test_not_all_variables_in_data(self):
4040
scenario = Scenario({self.X1, self.X2, self.X3, self.X4})
41-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
41+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
4242
self.assertRaises(IndexError, observational_data_collector.collect_data)
4343

4444
def test_all_variables_in_data(self):
4545
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2})
46-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
46+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
4747
df = observational_data_collector.collect_data(index_col=0)
4848
assert df.equals(self.observational_df), f"\n{df}\nwas not equal to\n{self.observational_df}"
4949

5050
def test_data_constraints(self):
5151
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2}, {self.X1.z3 > 2})
52-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
52+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
5353
df = observational_data_collector.collect_data(index_col=0)
5454
expected = self.observational_df.loc[[2, 3]]
5555
assert df.equals(expected), f"\n{df}\nwas not equal to\n{expected}"
@@ -60,7 +60,7 @@ def populate_m(data):
6060

6161
meta = Meta("M", int, populate_m)
6262
scenario = Scenario({self.X1, meta})
63-
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
63+
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
6464
data = observational_data_collector.collect_data()
6565
assert all((m == 2 * x1 for x1, m in zip(data["X1"], data["M"])))
6666

tests/json_front_tests/test_json_class.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def setUp(self) -> None:
2828
test_data_dir_path = Path("tests/resources/data")
2929
self.json_path = str(test_data_dir_path / json_file_name)
3030
self.dag_path = str(test_data_dir_path / dag_file_name)
31-
self.data_path = str(test_data_dir_path / data_file_name)
31+
self.data_path = [str(test_data_dir_path / data_file_name)]
3232
self.json_class = JsonUtility("logs.log")
3333
self.example_distribution = scipy.stats.uniform(1, 10)
3434
self.input_dict_list = [{"name": "test_input", "type": float, "distribution": self.example_distribution}]
@@ -40,7 +40,7 @@ def setUp(self) -> None:
4040
def test_setting_paths(self):
4141
self.assertEqual(self.json_class.paths.json_path, Path(self.json_path))
4242
self.assertEqual(self.json_class.paths.dag_path, Path(self.dag_path))
43-
self.assertEqual(self.json_class.paths.data_path, Path(self.data_path))
43+
self.assertEqual(self.json_class.paths.data_paths, [Path(self.data_path[0])]) # Needs to be list of Paths
4444

4545
def test_set_inputs(self):
4646
ctf_input = [Input("test_input", float, self.example_distribution)]
@@ -61,7 +61,7 @@ def test_set_metas(self):
6161

6262
def test_argparse(self):
6363
args = self.json_class.get_args(["--data_path=data.csv", "--dag_path=dag.dot", "--json_path=tests.json"])
64-
self.assertEqual(args.data_path, "data.csv")
64+
self.assertEqual(args.data_path, ["data.csv"])
6565
self.assertEqual(args.dag_path, "dag.dot")
6666
self.assertEqual(args.json_path, "tests.json")
6767

tests/testing_tests/test_causal_test_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def setUp(self) -> None:
5959

6060
# 5. Create observational data collector
6161
# Obsolete?
62-
self.data_collector = ObservationalDataCollector(self.scenario, self.observational_data_csv_path)
62+
self.data_collector = ObservationalDataCollector(self.scenario, df)
6363

6464
# 5. Create causal test engine
6565
self.causal_test_engine = CausalTestEngine(self.causal_specification, self.data_collector)

tests/testing_tests/test_causal_test_suite.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ def setUp(self) -> None:
4040
df = pd.DataFrame({"D": list(np.random.normal(60, 10, 1000))}) # D = exogenous
4141
df["A"] = [1 if d > 50 else 0 for d in df["D"]]
4242
df["C"] = df["D"] + (4 * (df["A"] + 2)) # C = (4*(A+2)) + D
43-
self.observational_data_csv_path = os.path.join(temp_dir_path, "observational_data.csv")
44-
df.to_csv(self.observational_data_csv_path, index=False)
45-
43+
self.df = df
4644
self.causal_dag = CausalDAG(dag_dot_path)
4745

4846
# 3. Specify data structures required for test suite
@@ -126,6 +124,6 @@ def create_causal_test_engine(self):
126124
"""
127125
causal_specification = CausalSpecification(self.scenario, self.causal_dag)
128126

129-
data_collector = ObservationalDataCollector(self.scenario, self.observational_data_csv_path)
127+
data_collector = ObservationalDataCollector(self.scenario, self.df)
130128
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
131129
return causal_test_engine

0 commit comments

Comments
 (0)