Skip to content

Commit db7b216

Browse files
Update test_causal_test_engine.py tests
1 parent 6b6bfe4 commit db7b216

File tree

1 file changed

+24
-28
lines changed

1 file changed

+24
-28
lines changed

tests/testing_tests/test_causal_test_engine.py

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from causal_testing.testing.causal_test_engine import CausalTestEngine
1212
from causal_testing.testing.causal_test_outcome import ExactValue
1313
from causal_testing.testing.estimators import CausalForestEstimator, LinearRegressionEstimator
14-
14+
from causal_testing.testing.base_test_case import BaseTestCase
1515

1616
class TestCausalTestEngineObservational(unittest.TestCase):
1717
""" Test the CausalTestEngine workflow using observational data.
@@ -41,11 +41,12 @@ def setUp(self) -> None:
4141

4242
# 3. Create a causal test case
4343
self.expected_causal_effect = ExactValue(4)
44+
self.base_test_case = BaseTestCase(A, C)
4445
self.causal_test_case = CausalTestCase(
45-
control_input_configuration={A: 0},
46+
base_test_case=self.base_test_case,
4647
expected_causal_effect=self.expected_causal_effect,
47-
treatment_input_configuration={A: 1},
48-
outcome_variables={C})
48+
control_value=0,
49+
treatment_value=1)
4950

5051
# 4. Create dummy test data and write to csv
5152
np.random.seed(1)
@@ -64,9 +65,7 @@ def setUp(self) -> None:
6465
self.causal_specification,
6566
self.data_collector
6667
)
67-
68-
self.causal_test_engine.identification(self.causal_test_case)
69-
self.minimal_adjustment_set = self.causal_test_engine.minimal_adjustment_set
68+
self.minimal_adjustment_set = self.causal_dag.identification(self.base_test_case)
7069
# 6. Easier to access treatment and outcome values
7170
self.treatment_value = 1
7271
self.control_value = 0
@@ -114,9 +113,8 @@ def test_check_positivity_violation_missing_outcome(self):
114113

115114
def test_check_minimum_adjustment_set(self):
116115
""" Check that the minimum adjustment set is correctly made"""
117-
self.causal_test_engine.identification(self.causal_test_case)
118-
minimum_adjustment_set = self.causal_test_engine.minimal_adjustment_set
119-
self.assertEqual(minimum_adjustment_set, {'D'})
116+
minimal_adjustment_set = self.causal_dag.identification(self.base_test_case)
117+
self.assertEqual(minimal_adjustment_set, {'D'})
120118

121119
def test_execute_test_observational_causal_forest_estimator(self):
122120
""" Check that executing the causal test case returns the correct results for the dummy data using a causal
@@ -134,19 +132,14 @@ def test_execute_test_observational_causal_forest_estimator(self):
134132
def test_invalid_causal_effect(self):
135133
""" Check that executing the causal test case returns the correct results for dummy data using a linear
136134
regression estimator. """
137-
causal_test_case = CausalTestCase(
138-
control_input_configuration={self.A: 0},
139-
expected_causal_effect=self.expected_causal_effect,
140-
treatment_input_configuration={self.A: 1},
141-
outcome_variables={self.C},
142-
effect="error")
143-
# 5. Create causal test engine
144-
causal_test_engine = CausalTestEngine(
145-
self.causal_specification,
146-
self.data_collector
135+
base_test_case = BaseTestCase(
136+
treatment_variable=self.A,
137+
outcome_variable=self.C,
138+
effect="error"
147139
)
140+
148141
with self.assertRaises(Exception):
149-
causal_test_engine.identification()
142+
self.causal_dag.identification(base_test_case)
150143

151144

152145
def test_execute_test_observational_linear_regression_estimator(self):
@@ -165,27 +158,30 @@ def test_execute_test_observational_linear_regression_estimator(self):
165158
def test_execute_test_observational_linear_regression_estimator_direct_effect(self):
166159
""" Check that executing the causal test case returns the correct results for dummy data using a linear
167160
regression estimator. """
161+
base_test_case = BaseTestCase(
162+
treatment_variable=self.A,
163+
outcome_variable=self.C,
164+
effect="direct")
165+
168166
causal_test_case = CausalTestCase(
169-
control_input_configuration={self.A: 0},
167+
base_test_case=base_test_case,
170168
expected_causal_effect=self.expected_causal_effect,
171-
treatment_input_configuration={self.A: 1},
172-
outcome_variables={self.C},
173-
effect="direct")
169+
control_value=0,
170+
treatment_value=1)
174171

175172
# 5. Create causal test engine
176173
causal_test_engine = CausalTestEngine(
177174
self.causal_specification,
178175
self.data_collector
179176
)
180-
causal_test_engine.identification(causal_test_case)
181-
self.minimal_adjustment_set = causal_test_engine.minimal_adjustment_set
177+
minimal_adjustment_set = self.causal_dag.identification(base_test_case)
182178
# 6. Easier to access treatment and outcome values
183179
self.treatment_value = 1
184180
self.control_value = 0
185181
estimation_model = LinearRegressionEstimator(('A',),
186182
self.treatment_value,
187183
self.control_value,
188-
self.minimal_adjustment_set,
184+
minimal_adjustment_set,
189185
('C',),
190186
causal_test_engine.scenario_execution_data_df)
191187
causal_test_result = causal_test_engine.execute_test(estimation_model, causal_test_case)

0 commit comments

Comments
 (0)