Skip to content

Commit daa6b10

Browse files
Refactor test_suite to be a UserDict class
1 parent b01ab8a commit daa6b10

File tree

3 files changed

+15
-19
lines changed

3 files changed

+15
-19
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
6363
if self.scenario_execution_data_df.empty:
6464
raise Exception("No data has been loaded. Please call load_data prior to executing a causal test case.")
6565
test_suite_results = {}
66-
test_suite_dict = test_suite.test_suite
67-
for edge in test_suite_dict:
66+
for edge in test_suite:
6867
print("edge: ")
6968
print(edge)
7069
logger.info("treatment: %s", edge.treatment_variable)
@@ -81,9 +80,9 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
8180
# TODO: When we implement causal contracts, we should also note the positivity violation there
8281
raise Exception("POSITIVITY VIOLATION -- Cannot proceed.")
8382

84-
estimators = test_suite_dict[edge]["estimators"]
85-
tests = test_suite_dict[edge]["tests"]
86-
estimate_type = test_suite_dict[edge]["estimate_type"]
83+
estimators = test_suite[edge]["estimators"]
84+
tests = test_suite[edge]["tests"]
85+
estimate_type = test_suite[edge]["estimate_type"]
8786
results = {}
8887
for EstimatorClass in estimators:
8988
causal_test_results = []
Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
class CausalTestSuite:
1+
from collections import UserDict
2+
3+
4+
class CausalTestSuite(UserDict):
25
"""
36
A CausalTestSuite is a class holding a dictionary, where the keys are base_test_cases representing the treatment
47
and outcome Variables, and the values are test objects. Test Objects hold a causal_test_case_list which is a
@@ -8,14 +11,6 @@ class CausalTestSuite:
811
base_test_case's and execute each causal_test_case with each iterator.
912
"""
1013

11-
def __init__(
12-
self,
13-
):
14-
self.test_suite = {}
15-
1614
def add_test_object(self, base_test_case, causal_test_case_list, estimators, estimate_type: str = "ate"):
1715
test_object = {"tests": causal_test_case_list, "estimators": estimators, "estimate_type": estimate_type}
18-
self.test_suite[base_test_case] = test_object
19-
20-
def get_single_test_object(self, base_test_case):
21-
return self.test_suite[base_test_case]
16+
self.data[base_test_case] = test_object

tests/testing_tests/test_causal_test_suite.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_adding_test_object(self):
7171
estimators=estimators)
7272
manual_test_object = {
7373
self.base_test_case: {"tests": test_list, "estimators": estimators, "estimate_type": "ate"}}
74-
self.assertEqual(test_suite.test_suite, manual_test_object)
74+
self.assertEqual(test_suite, manual_test_object)
7575

7676
def test_return_single_test_object(self):
7777
base_test_case = BaseTestCase(self.A, self.D)
@@ -87,7 +87,7 @@ def test_return_single_test_object(self):
8787

8888
manual_test_case = {"tests": test_list, "estimators": estimators, "estimate_type": "ate"}
8989

90-
test_case = self.test_suite.get_single_test_object(base_test_case)
90+
test_case = self.test_suite[base_test_case]
9191

9292
self.assertEqual(test_case, manual_test_case)
9393

@@ -103,8 +103,10 @@ def test_execute_test_suite_multiple_base_test_cases(self):
103103
pass
104104

105105
def test_execute_test_suite_multiple_estimators(self):
106-
pass
107-
106+
#estimators = [LinearRegressionEstimator, CausalForestEstimator]
107+
#test_suite_2_estimators = CausalTestSuite
108+
#test_suite_2_estimators
109+
causal_test_engine = self.create_causal_test_engine()
108110
def create_causal_test_engine(self):
109111

110112
causal_specification = CausalSpecification(self.scenario, self.causal_dag)

0 commit comments

Comments
 (0)