Skip to content

Commit 08df6ae

Browse files
committed
Moved effect type to causal test case rather than causal test engine
1 parent 5b0dced commit 08df6ae

File tree

3 files changed

+22
-11
lines changed

3 files changed

+22
-11
lines changed

causal_testing/testing/causal_test_case.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class CausalTestCase:
1818

1919
def __init__(self, control_input_configuration: dict[Variable: Any], expected_causal_effect: CausalTestOutcome,
2020
outcome_variables: dict[Variable], treatment_input_configuration: dict[Variable: Any] = None,
21-
estimate_type: str = "ate", effect_modifier_configuration: dict[Variable: Any] = None):
21+
estimate_type: str = "ate", effect_modifier_configuration: dict[Variable: Any] = None, effect: str = "total"):
2222
"""
2323
When a CausalTestCase is initialised, it takes the intervention and applies it to the input configuration to
2424
create two distinct input configurations: a control input configuration and a treatment input configuration.
@@ -35,6 +35,7 @@ def __init__(self, control_input_configuration: dict[Variable: Any], expected_ca
3535
self.outcome_variables = outcome_variables
3636
self.treatment_input_configuration = treatment_input_configuration
3737
self.estimate_type = estimate_type
38+
self.effect = effect
3839
if effect_modifier_configuration:
3940
self.effect_modifier_configuration = effect_modifier_configuration
4041
else:

causal_testing/testing/causal_test_engine.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,12 @@ class CausalTestEngine:
3030
"""
3131

3232
def __init__(self, causal_test_case: CausalTestCase, causal_specification: CausalSpecification,
33-
data_collector: DataCollector, effect: str = "total"):
33+
data_collector: DataCollector):
3434
self.causal_test_case = causal_test_case
3535
self.treatment_variables = list(self.causal_test_case.control_input_configuration)
3636
self.casual_dag, self.scenario = causal_specification.causal_dag, causal_specification.scenario
3737
self.data_collector = data_collector
3838
self.scenario_execution_data_df = pd.DataFrame()
39-
self.effect = effect
4039

4140
def load_data(self, **kwargs):
4241
""" Load execution data corresponding to the causal test case into a pandas dataframe and return the minimal
@@ -61,12 +60,12 @@ def load_data(self, **kwargs):
6160
self.scenario_execution_data_df = self.data_collector.collect_data(**kwargs)
6261

6362
minimal_adjustment_sets = []
64-
if self.effect == "total":
63+
if self.causal_test_case.effect == "total":
6564
minimal_adjustment_sets = self.casual_dag.enumerate_minimal_adjustment_sets(
6665
[v.name for v in self.treatment_variables],
6766
[v.name for v in self.causal_test_case.outcome_variables]
6867
)
69-
elif self.effect == "direct":
68+
elif self.causal_test_case.effect == "direct":
7069
minimal_adjustment_sets = self.casual_dag.direct_effect_adjustment_sets(
7170
[v.name for v in self.treatment_variables],
7271
[v.name for v in self.causal_test_case.outcome_variables]

tests/testing_tests/test_causal_test_engine.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,17 @@ def test_execute_test_observational_causal_forest_estimator(self):
145145
def test_invalid_causal_effect(self):
146146
""" Check that executing the causal test case returns the correct results for dummy data using a linear
147147
regression estimator. """
148+
causal_test_case = CausalTestCase(
149+
control_input_configuration={self.A: 0},
150+
expected_causal_effect=self.expected_causal_effect,
151+
treatment_input_configuration={self.A: 1},
152+
outcome_variables={self.C},
153+
effect="error")
148154
# 5. Create causal test engine
149155
causal_test_engine = CausalTestEngine(
150-
self.causal_test_case,
156+
causal_test_case,
151157
self.causal_specification,
152-
self.data_collector,
153-
effect="error"
158+
self.data_collector
154159
)
155160
with self.assertRaises(Exception):
156161
causal_test_engine.load_data()
@@ -172,12 +177,18 @@ def test_execute_test_observational_linear_regression_estimator(self):
172177
def test_execute_test_observational_linear_regression_estimator_direct_effect(self):
173178
""" Check that executing the causal test case returns the correct results for dummy data using a linear
174179
regression estimator. """
180+
causal_test_case = CausalTestCase(
181+
control_input_configuration={self.A: 0},
182+
expected_causal_effect=self.expected_causal_effect,
183+
treatment_input_configuration={self.A: 1},
184+
outcome_variables={self.C},
185+
effect="direct")
186+
175187
# 5. Create causal test engine
176188
causal_test_engine = CausalTestEngine(
177-
self.causal_test_case,
189+
causal_test_case,
178190
self.causal_specification,
179-
self.data_collector,
180-
effect="direct"
191+
self.data_collector
181192
)
182193
self.minimal_adjustment_set = causal_test_engine.load_data()
183194

0 commit comments

Comments
 (0)