Skip to content

Commit d394bdd

Browse files
Add test cases for causal_test_suite
1 parent 2d68658 commit d394bdd

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import unittest
2+
import os
3+
import numpy as np
4+
import pandas as pd
5+
from causal_testing.testing.causal_test_engine import CausalTestEngine
6+
from causal_testing.testing.causal_test_engine import CausalTestSuite
7+
from causal_testing.testing.causal_test_case import CausalTestCase
8+
from causal_testing.testing.base_test_case import BaseTestCase
9+
from causal_testing.specification.variable import Input, Output
10+
from causal_testing.testing.causal_test_outcome import ExactValue
11+
from causal_testing.testing.estimators import CausalForestEstimator, LinearRegressionEstimator
12+
from causal_testing.specification.causal_specification import CausalSpecification, Scenario
13+
from causal_testing.data_collection.data_collector import ObservationalDataCollector
14+
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
15+
from causal_testing.specification.causal_dag import CausalDAG
16+
17+
18+
class TestCausalTestSuite(unittest.TestCase):
19+
"""
20+
21+
"""
22+
23+
def setUp(self) -> None:
24+
self.test_suite = CausalTestSuite()
25+
A = Input("A", float)
26+
self.A = A
27+
C = Output("C", float)
28+
self.C = C
29+
D = Output("D", float)
30+
self.D = D
31+
self.base_test_case = BaseTestCase(A, C)
32+
self.expected_causal_effect = ExactValue(4)
33+
test_list = [CausalTestCase(self.base_test_case,
34+
self.expected_causal_effect,
35+
0,
36+
1, ),
37+
CausalTestCase(self.base_test_case,
38+
self.expected_causal_effect,
39+
0,
40+
2)]
41+
estimators = [LinearRegressionEstimator]
42+
self.test_suite.add_test_object(base_test_case=self.base_test_case,
43+
causal_test_case_list=test_list,
44+
estimators=estimators)
45+
46+
temp_dir_path = create_temp_dir_if_non_existent()
47+
dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
48+
dag_dot = """digraph G { A -> C; D -> A; D -> C}"""
49+
f = open(dag_dot_path, "w")
50+
f.write(dag_dot)
51+
f.close()
52+
self.causal_dag = CausalDAG(dag_dot_path)
53+
self.scenario = Scenario({A, C, D})
54+
55+
np.random.seed(1)
56+
df = pd.DataFrame({"D": list(np.random.normal(60, 10, 1000))}) # D = exogenous
57+
df["A"] = [1 if d > 50 else 0 for d in df["D"]]
58+
df["C"] = df["D"] + (4 * (df["A"] + 2)) # C = (4*(A+2)) + D
59+
self.observational_data_csv_path = os.path.join(temp_dir_path, "observational_data.csv")
60+
df.to_csv(self.observational_data_csv_path, index=False)
61+
62+
def test_adding_test_object(self):
63+
test_suite = CausalTestSuite()
64+
test_list = [CausalTestCase(self.base_test_case,
65+
self.expected_causal_effect,
66+
0,
67+
1)]
68+
estimators = [LinearRegressionEstimator]
69+
test_suite.add_test_object(base_test_case=self.base_test_case,
70+
causal_test_case_list=test_list,
71+
estimators=estimators)
72+
manual_test_object = {
73+
self.base_test_case: {"tests": test_list, "estimators": estimators, "estimate_type": "ate"}}
74+
self.assertEqual(test_suite.test_suite, manual_test_object)
75+
76+
def test_return_single_test_object(self):
77+
base_test_case = BaseTestCase(self.A, self.D)
78+
79+
test_list = [CausalTestCase(self.base_test_case,
80+
self.expected_causal_effect,
81+
0,
82+
1)]
83+
estimators = [LinearRegressionEstimator]
84+
self.test_suite.add_test_object(base_test_case=base_test_case,
85+
causal_test_case_list=test_list,
86+
estimators=estimators)
87+
88+
manual_test_case = {"tests": test_list, "estimators": estimators, "estimate_type": "ate"}
89+
90+
test_case = self.test_suite.get_single_test_object(base_test_case)
91+
92+
self.assertEqual(test_case, manual_test_case)
93+
94+
def test_execute_test_suite_single_test_case(self):
95+
causal_specification = CausalSpecification(self.scenario, self.causal_dag)
96+
97+
data_collector = ObservationalDataCollector(self.scenario, self.observational_data_csv_path)
98+
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
99+
100+
causal_test_results = causal_test_engine.execute_test_suite(test_suite=self.test_suite)
101+
breakpoint()
102+
self.assertAlmostEqual(causal_test_results[self.base_test_case][0][0].ate, 4, delta=1e-10)
103+
104+
def test_execute_test_suite_multiple_test_case(self):
105+
pass
106+
107+
def test_execute_test_suite_multiple_estimators(self):
108+
pass

0 commit comments

Comments
 (0)