Skip to content

Commit d53fef2

Browse files
committed
No mention of CausalTestEngine
1 parent 7620f96 commit d53fef2

File tree

4 files changed

+11
-20
lines changed

4 files changed

+11
-20
lines changed

causal_testing/testing/causal_test_adequacy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
This module contains code to measure various aspects of causal test adequacy.
33
"""
44
from causal_testing.testing.causal_test_suite import CausalTestSuite
5+
from causal_testing.data_collection.data_collector import DataCollector
56
from causal_testing.specification.causal_specification import CausalSpecification
67
from causal_testing.testing.estimators import Estimator
78
from causal_testing.testing.causal_test_case import CausalTestCase
8-
from causal_testing.testing.causal_test_engine import CausalTestEngine
99
from itertools import combinations
1010
from copy import deepcopy
1111
from sklearn.model_selection import KFold
@@ -35,11 +35,11 @@ def measure_adequacy(self):
3535

3636
class DataAdequacy:
3737
def __init__(
38-
self, test_case: CausalTestCase, test_engine: CausalTestEngine, estimator: Estimator, bootstrap_size: int = 100
38+
self, test_case: CausalTestCase, estimator: Estimator, data_collector: DataCollector, bootstrap_size: int = 100
3939
):
4040
self.test_case = test_case
41-
self.test_engine = test_engine
4241
self.estimator = estimator
42+
self.data_collector = data_collector
4343
self.kurtosis = None
4444
self.outcomes = None
4545
self.bootstrap_size = bootstrap_size
@@ -50,7 +50,7 @@ def measure_adequacy(self):
5050
estimator = deepcopy(self.estimator)
5151
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
5252
try:
53-
results.append(self.test_engine.execute_test(estimator, self.test_case))
53+
results.append(self.test.execute_test(estimator, self.data_collector))
5454
except np.LinAlgError:
5555
continue
5656
outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]

causal_testing/testing/causal_test_case.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,8 @@ class CausalTestCase:
1818
"""
1919
A CausalTestCase extends the information held in a BaseTestCase. As well as storing the treatment and outcome
2020
variables, a CausalTestCase stores the values of these variables. Also the outcome variable and value are
21-
specified.
22-
23-
The goal of a CausalTestCase is to test whether the intervention made to the control via the treatment causes the
24-
model-under-test to produce the expected change. The CausalTestCase structure is designed for execution using the
25-
CausalTestEngine, using either execute_test() function to execute a single test case or packing CausalTestCases into
26-
a CausalTestSuite and executing them as a batch using the execute_test_suite() function.
21+
specified. The goal of a CausalTestCase is to test whether the intervention made to the control via the treatment causes the
22+
model-under-test to produce the expected change.
2723
"""
2824

2925
def __init__(

causal_testing/testing/causal_test_suite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class CausalTestSuite(UserDict):
2222
and the values are test objects. Test Objects hold a causal_test_case_list which is a list of causal_test_cases
2323
which provide control and treatment values, and an iterator of Estimator Class References
2424
25-
This dictionary can be fed into the CausalTestEngines execute_test_suite function which will iterate over all the
25+
This dictionary can be fed into the execute_test_suite function which will iterate over all the
2626
base_test_case's and execute each causal_test_case with each iterator.
2727
"""
2828

tests/testing_tests/test_causal_test_case.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,8 @@ def tearDown(self) -> None:
6060
remove_temp_dir_if_existent()
6161

6262

63-
6463
class TestCausalTestExecution(unittest.TestCase):
65-
"""Test the CausalTestEngine workflow using observational data.
64+
"""Test the causal test execution workflow using observational data.
6665
6766
The causal test engine (CTE) is the main workflow for the causal testing framework. The CTE takes a causal test case
6867
and a causal specification and computes the causal effect of the intervention on the outcome of interest.
@@ -246,12 +245,8 @@ def test_execute_observational_causal_forest_estimator_cates(self):
246245
"""Check that executing the causal test case returns the correct conditional average treatment effects for
247246
dummy data with effect multiplicative effect modification. C ~ (4*(A+2) + D)*M"""
248247
# Add some effect modifier M that has a multiplicative effect on C
249-
self.df["M"] = np.random.randint(
250-
1, 5, len(self.df)
251-
)
252-
self.df["C"] *= self.df[
253-
"M"
254-
]
248+
self.df["M"] = np.random.randint(1, 5, len(self.df))
249+
self.df["C"] *= self.df["M"]
255250
estimation_model = CausalForestEstimator(
256251
"A",
257252
self.treatment_value,
@@ -274,4 +269,4 @@ def test_execute_observational_causal_forest_estimator_cates(self):
274269
self.assertLess(causal_test_result_m3["cate"].mean(), causal_test_result_m4["cate"].mean())
275270

276271
def tearDown(self) -> None:
277-
remove_temp_dir_if_existent()
272+
remove_temp_dir_if_existent()

0 commit comments

Comments
 (0)