Skip to content

Commit 2cdb50f

Browse files
Docstrings and comments
1 parent c920b97 commit 2cdb50f

File tree

1 file changed

+51
-44
lines changed

1 file changed

+51
-44
lines changed

tests/testing_tests/test_causal_test_suite.py

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,26 @@
1616

1717

1818
class TestCausalTestSuite(unittest.TestCase):
19-
"""
20-
21-
"""
19+
"""Test the Test Suite object and it's implementation in the test engine using dummy data."""
2220

2321
def setUp(self) -> None:
24-
self.test_suite = CausalTestSuite()
22+
# 1. Create dummy Scenario and BaseTestCase
2523
A = Input("A", float)
2624
self.A = A
2725
C = Output("C", float)
2826
self.C = C
2927
D = Output("D", float)
3028
self.D = D
3129
self.base_test_case = BaseTestCase(A, C)
32-
self.expected_causal_effect = ExactValue(4)
33-
test_list = [CausalTestCase(self.base_test_case,
34-
self.expected_causal_effect,
35-
0,
36-
1, ),
37-
CausalTestCase(self.base_test_case,
38-
self.expected_causal_effect,
39-
0,
40-
2)]
41-
self.estimators = [LinearRegressionEstimator]
42-
self.test_suite.add_test_object(base_test_case=self.base_test_case,
43-
causal_test_case_list=test_list,
44-
estimators=self.estimators)
30+
self.scenario = Scenario({A, C, D})
4531

32+
# 2. Create DAG and dummy data and write to csvs
4633
temp_dir_path = create_temp_dir_if_non_existent()
4734
dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
4835
dag_dot = """digraph G { A -> C; D -> A; D -> C}"""
4936
f = open(dag_dot_path, "w")
5037
f.write(dag_dot)
5138
f.close()
52-
self.causal_dag = CausalDAG(dag_dot_path)
53-
self.scenario = Scenario({A, C, D})
5439

5540
np.random.seed(1)
5641
df = pd.DataFrame({"D": list(np.random.normal(60, 10, 1000))}) # D = exogenous
@@ -59,31 +44,49 @@ def setUp(self) -> None:
5944
self.observational_data_csv_path = os.path.join(temp_dir_path, "observational_data.csv")
6045
df.to_csv(self.observational_data_csv_path, index=False)
6146

47+
self.causal_dag = CausalDAG(dag_dot_path)
48+
49+
# 3. Specify data structures required for test suite
50+
self.expected_causal_effect = ExactValue(4)
51+
test_list = [
52+
CausalTestCase(
53+
self.base_test_case,
54+
self.expected_causal_effect,
55+
0,
56+
1,
57+
),
58+
CausalTestCase(self.base_test_case, self.expected_causal_effect, 0, 2),
59+
]
60+
self.estimators = [LinearRegressionEstimator]
61+
62+
# 3. Create test_suite and add a test
63+
self.test_suite = CausalTestSuite()
64+
self.test_suite.add_test_object(
65+
base_test_case=self.base_test_case, causal_test_case_list=test_list, estimators=self.estimators
66+
)
67+
6268
def test_adding_test_object(self):
69+
"test an object can be added to the test_suite using the add_test_object function"
6370
test_suite = CausalTestSuite()
64-
test_list = [CausalTestCase(self.base_test_case,
65-
self.expected_causal_effect,
66-
0,
67-
1)]
71+
test_list = [CausalTestCase(self.base_test_case, self.expected_causal_effect, 0, 1)]
6872
estimators = [LinearRegressionEstimator]
69-
test_suite.add_test_object(base_test_case=self.base_test_case,
70-
causal_test_case_list=test_list,
71-
estimators=estimators)
73+
test_suite.add_test_object(
74+
base_test_case=self.base_test_case, causal_test_case_list=test_list, estimators=estimators
75+
)
7276
manual_test_object = {
73-
self.base_test_case: {"tests": test_list, "estimators": estimators, "estimate_type": "ate"}}
77+
self.base_test_case: {"tests": test_list, "estimators": estimators, "estimate_type": "ate"}
78+
}
7479
self.assertEqual(test_suite, manual_test_object)
7580

7681
def test_return_single_test_object(self):
82+
"""Test that a single test case can be returned from the test_suite"""
7783
base_test_case = BaseTestCase(self.A, self.D)
7884

79-
test_list = [CausalTestCase(self.base_test_case,
80-
self.expected_causal_effect,
81-
0,
82-
1)]
85+
test_list = [CausalTestCase(self.base_test_case, self.expected_causal_effect, 0, 1)]
8386
estimators = [LinearRegressionEstimator]
84-
self.test_suite.add_test_object(base_test_case=base_test_case,
85-
causal_test_case_list=test_list,
86-
estimators=estimators)
87+
self.test_suite.add_test_object(
88+
base_test_case=base_test_case, causal_test_case_list=test_list, estimators=estimators
89+
)
8790

8891
manual_test_case = {"tests": test_list, "estimators": estimators, "estimate_type": "ate"}
8992

@@ -97,27 +100,31 @@ def test_execute_test_suite_single_base_test_case(self):
97100

98101
causal_test_results = causal_test_engine.execute_test_suite(test_suite=self.test_suite)
99102
causal_test_case_result = causal_test_results[self.base_test_case]
100-
self.assertAlmostEqual(causal_test_case_result['LinearRegressionEstimator'][0].ate, 4, delta=1e-10)
103+
self.assertAlmostEqual(causal_test_case_result["LinearRegressionEstimator"][0].ate, 4, delta=1e-10)
101104

102105
def test_execute_test_suite_multiple_estimators(self):
106+
"""Check that executing a test suite with multiple estimators returns correct results for the dummy data
107+
for each estimator
108+
"""
103109
estimators = [LinearRegressionEstimator, CausalForestEstimator]
104110
test_suite_2_estimators = CausalTestSuite()
105-
test_list = [CausalTestCase(self.base_test_case,
106-
self.expected_causal_effect,
107-
0,
108-
1)]
109-
test_suite_2_estimators.add_test_object(base_test_case=self.base_test_case,
110-
causal_test_case_list=test_list,
111-
estimators=estimators)
111+
test_list = [CausalTestCase(self.base_test_case, self.expected_causal_effect, 0, 1)]
112+
test_suite_2_estimators.add_test_object(
113+
base_test_case=self.base_test_case, causal_test_case_list=test_list, estimators=estimators
114+
)
112115
causal_test_engine = self.create_causal_test_engine()
113116
causal_test_results = causal_test_engine.execute_test_suite(test_suite=test_suite_2_estimators)
114117
causal_test_case_result = causal_test_results[self.base_test_case]
115-
linear_regression_result = causal_test_case_result['LinearRegressionEstimator'][0]
116-
causal_forrest_result = causal_test_case_result['CausalForestEstimator'][0]
118+
linear_regression_result = causal_test_case_result["LinearRegressionEstimator"][0]
119+
causal_forrest_result = causal_test_case_result["CausalForestEstimator"][0]
117120
self.assertAlmostEqual(linear_regression_result.ate, 4, delta=1e-1)
118121
self.assertAlmostEqual(causal_forrest_result.ate, 4, delta=1e-1)
119122

120123
def create_causal_test_engine(self):
124+
"""
125+
Creating test engine is relatively computationally complex, this function allows for it to
126+
easily be made in only the tests that require it.
127+
"""
121128
causal_specification = CausalSpecification(self.scenario, self.causal_dag)
122129

123130
data_collector = ObservationalDataCollector(self.scenario, self.observational_data_csv_path)

0 commit comments

Comments
 (0)