Skip to content

Commit 7871296

Browse files
committed
Testing main.py
1 parent ba72d7e commit 7871296

File tree

3 files changed

+54
-7
lines changed

3 files changed

+54
-7
lines changed

tests/main_tests/test_main.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import unittest
22
import shutil
3+
import json
34
from pathlib import Path
4-
from causal_testing.main import CausalTestingPaths
5+
from causal_testing.main import CausalTestingPaths, CausalTestingFramework
56

67

78
class TestCausalTestingPaths(unittest.TestCase):
@@ -35,3 +36,45 @@ def test_output_file_created(self):
3536
def tearDown(self):
3637
if self.output_path.parent.exists():
3738
shutil.rmtree(self.output_path.parent)
39+
40+
41+
class TestCausalTestingFramework(unittest.TestCase):
42+
def setUp(self):
43+
self.dag_path = "tests/resources/data/dag.dot"
44+
self.data_paths = ["tests/resources/data/data.csv"]
45+
self.test_config_path = "tests/resources/data/tests.json"
46+
self.output_path = Path("results/results.json")
47+
48+
def test_ctf(self):
49+
# Create paths object
50+
paths = CausalTestingPaths(
51+
dag_path=self.dag_path,
52+
data_paths=self.data_paths,
53+
test_config_path=self.test_config_path,
54+
output_path=self.output_path,
55+
)
56+
57+
# Create and setup framework
58+
framework = CausalTestingFramework(paths)
59+
framework.setup()
60+
61+
# Load and run tests
62+
framework.load_tests()
63+
results = framework.run_tests()
64+
65+
# Save results
66+
framework.save_results(results)
67+
68+
with open(self.test_config_path, "r", encoding="utf-8") as f:
69+
test_configs = json.load(f)
70+
71+
tests_passed = [
72+
test_case.expected_causal_effect.apply(result) if result.test_value.type != "Error" else False
73+
for test_config, test_case, result in zip(test_configs["tests"], framework.test_cases, results)
74+
]
75+
76+
self.assertEqual(tests_passed, [True])
77+
78+
def tearDown(self):
79+
if self.output_path.parent.exists():
80+
shutil.rmtree(self.output_path.parent)

tests/resources/data/data.csv

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
1-
index,test_input,test_input_no_dist,test_output
2-
0,1.0,1.0,2.0
1+
test_input,test_input_no_dist,test_output,B,C
2+
1.0,1.1,2.2,0,0
3+
2.0,1.1,2.8,0,0
4+
3.0,1.0,1.0,0,0
5+
4.0,1.2,6.0,0,0
6+
5.0,0.9,2.5,0,0

tests/resources/data/tests.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
{
22
"tests": [{
33
"name": "test1",
4-
"mutations": {},
5-
"estimator": null,
6-
"estimate_type": null,
4+
"treatment_variable": "test_input",
5+
"estimator": "LinearRegressionEstimator",
6+
"estimate_type": "coefficient",
77
"effect_modifiers": [],
8-
"expected_effect": {},
8+
"expected_effect": {"test_output": "NoEffect"},
99
"skip": false
1010
}]
1111
}

0 commit comments

Comments
 (0)