Skip to content

Commit 6b6bfe4

Browse files
Refactor BaseCausalTest to BaseTestCase
1 parent ed9c402 commit 6b6bfe4

File tree

9 files changed

+27
-28
lines changed

9 files changed

+27
-28
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from causal_testing.specification.variable import Variable
1010
from causal_testing.testing.causal_test_case import CausalTestCase
1111
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
12-
from causal_testing.testing.base_causal_test import BaseCausalTest
12+
from causal_testing.testing.base_test_case import BaseTestCase
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -110,10 +110,10 @@ def _generate_concrete_tests(
110110
)
111111
model = optimizer.model()
112112

113-
base_causal_test = BaseCausalTest(self.treatment_variables, list(self.expected_causal_effect.keys())[0])
113+
base_test_case = BaseTestCase(self.treatment_variables, list(self.expected_causal_effect.keys())[0])
114114

115115
concrete_test = CausalTestCase(
116-
base_causal_test=base_causal_test,
116+
base_test_case=base_test_case,
117117
control_value=self.treatment_variables.cast(model[self.treatment_variables.z3]),
118118
treatment_value=self.treatment_variables.cast(
119119
model[self.scenario.treatment_variables[self.treatment_variables.name].z3]

causal_testing/json_front/json_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
205205
"""
206206
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
207207
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
208-
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_causal_test)
208+
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
209209
treatment_vars = list(causal_test_case.treatment_input_configuration)
210210
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in treatment_vars}
211211
estimation_model = estimator(

causal_testing/testing/base_causal_test.py renamed to causal_testing/testing/base_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
@dataclass(frozen=True)
6-
class BaseCausalTest:
6+
class BaseTestCase:
77
"""
88
A base causal test case represents the relationship of an edge on a causal DAG.
99
"""

causal_testing/testing/causal_test_case.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from causal_testing.specification.variable import Variable
55
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
6-
from causal_testing.testing.base_causal_test import BaseCausalTest
6+
from causal_testing.testing.base_test_case import BaseTestCase
77

88
logger = logging.getLogger(__name__)
99

@@ -18,7 +18,7 @@ class CausalTestCase:
1818

1919
def __init__(
2020
self,
21-
base_causal_test: BaseCausalTest,
21+
base_test_case: BaseTestCase,
2222
expected_causal_effect: CausalTestOutcome,
2323
control_value: Any,
2424
treatment_value: Any = None,
@@ -36,13 +36,13 @@ def __init__(
3636
:param treatment_input_configuration: The input configuration representing the treatment values of the treatment
3737
variables. That is, the input configuration *after* applying the intervention.
3838
"""
39-
self.base_causal_test = base_causal_test
40-
self.control_input_configuration = {base_causal_test.treatment_variable: control_value}
39+
self.base_test_case = base_test_case
40+
self.control_input_configuration = {base_test_case.treatment_variable: control_value}
4141
self.expected_causal_effect = expected_causal_effect
42-
self.outcome_variable = base_causal_test.outcome_variable
43-
self.treatment_input_configuration = {base_causal_test.treatment_variable: treatment_value}
42+
self.outcome_variable = base_test_case.outcome_variable
43+
self.treatment_input_configuration = {base_test_case.treatment_variable: treatment_value}
4444
self.estimate_type = estimate_type
45-
self.effect = base_causal_test.effect
45+
self.effect = base_test_case.effect
4646
if effect_modifier_configuration:
4747
self.effect_modifier_configuration = effect_modifier_configuration
4848
else:

causal_testing/testing/causal_test_engine.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from causal_testing.testing.causal_test_case import CausalTestCase
88
from causal_testing.testing.causal_test_outcome import CausalTestResult
99
from causal_testing.testing.estimators import Estimator
10-
from causal_testing.testing.base_causal_test import BaseCausalTest
10+
from causal_testing.testing.base_test_case import BaseTestCase
1111
from causal_testing.testing.causal_test_suite import CausalTestSuite
1212

1313
logger = logging.getLogger(__name__)
@@ -53,8 +53,7 @@ def __init__(self, causal_specification: CausalSpecification, data_collector: Da
5353
self.data_collector = data_collector
5454
self.scenario_execution_data_df = self.data_collector.collect_data(**kwargs)
5555

56-
def execute_test_suite(
57-
self, test_suite: CausalTestSuite) -> list[CausalTestResult]:
56+
def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResult]:
5857

5958
"""Execute a suite of causal tests and return the results in a list
6059
:param test_suite: CasualTestSuite object
@@ -75,7 +74,7 @@ def execute_test_suite(
7574
minimal_adjustment_set = minimal_adjustment_set - set(edge.outcome_variable.name)
7675

7776
variables_for_positivity = (
78-
list(minimal_adjustment_set) + [edge.treatment_variable.name] + [edge.outcome_variable.name]
77+
list(minimal_adjustment_set) + [edge.treatment_variable.name] + [edge.outcome_variable.name]
7978
)
8079
if self._check_positivity_violation(variables_for_positivity):
8180
# TODO: We should allow users to continue because positivity can be overcome with parametric models
@@ -110,7 +109,7 @@ def execute_test_suite(
110109
return test_suite_results
111110

112111
def execute_test(
113-
self, estimator: type(Estimator), causal_test_case: CausalTestCase, estimate_type: str = "ate"
112+
self, estimator: type(Estimator), causal_test_case: CausalTestCase, estimate_type: str = "ate"
114113
) -> CausalTestResult:
115114
"""Execute a causal test case and return the causal test result.
116115
@@ -139,7 +138,7 @@ def execute_test(
139138

140139
logger.info("treatments: %s", treatments)
141140
logger.info("outcomes: %s", outcome_variable)
142-
minimal_adjustment_set = self.causal_dag.identification(BaseCausalTest(treatment_variable, outcome_variable))
141+
minimal_adjustment_set = self.causal_dag.identification(BaseTestCase(treatment_variable, outcome_variable))
143142
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in causal_test_case.control_input_configuration}
144143
minimal_adjustment_set = minimal_adjustment_set - set(outcome_variable.name)
145144
assert all(

causal_testing/testing/causal_test_suite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ class CausalTestSuite:
99
"""
1010

1111
def __init__(
12-
self,
12+
self,
1313
):
1414
self.test_suite = {}
1515

1616
def add_test_object(self, base_test_case, causal_test_case_list, estimators, estimate_type):
17-
test_object = {'tests': causal_test_case_list, 'estimators': estimators, 'estimate_type': estimate_type}
17+
test_object = {"tests": causal_test_case_list, "estimators": estimators, "estimate_type": estimate_type}
1818
self.test_suite[base_test_case] = test_object
1919

2020
def get_single_test_object(self, base_test_case):

examples/lr91/causal_test_max_conductances.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
1111
from causal_testing.testing.causal_test_engine import CausalTestEngine
1212
from causal_testing.testing.estimators import LinearRegressionEstimator
13-
from causal_testing.testing.base_causal_test import BaseCausalTest
13+
from causal_testing.testing.base_test_case import BaseTestCase
1414
from matplotlib.pyplot import rcParams
1515

1616
# Uncommenting the code below will make all graphs publication quality but requires a suitable latex installation
@@ -110,9 +110,9 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
110110

111111
# 5. Create a causal specification from the scenario and causal DAG
112112
causal_specification = CausalSpecification(scenario, causal_dag)
113-
base_test_case = BaseCausalTest(treatment_var, apd90)
113+
base_test_case = BaseTestCase(treatment_var, apd90)
114114
# 6. Create a causal test case
115-
causal_test_case = CausalTestCase(base_causal_test=base_test_case,
115+
causal_test_case = CausalTestCase(base_test_case=base_test_case,
116116
expected_causal_effect=expected_causal_effect,
117117
control_value=control_val,
118118
treatment_value=treatment_val)

examples/lr91/causal_test_max_conductances_test_suite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
1111
from causal_testing.testing.causal_test_engine import CausalTestEngine
1212
from causal_testing.testing.estimators import LinearRegressionEstimator
13-
from causal_testing.testing.base_causal_test import BaseCausalTest
13+
from causal_testing.testing.base_test_case import BaseTestCase
1414
from causal_testing.testing.causal_test_suite import CausalTestSuite
1515
from matplotlib.pyplot import rcParams
1616

@@ -63,7 +63,7 @@ def causal_testing_sensitivity_analysis():
6363

6464
for conductance_param, mean_and_oracle in conductance_means.items():
6565
treatment_variable = Input(conductance_param, float)
66-
base_test_case = BaseCausalTest(treatment_variable, outcome_variable)
66+
base_test_case = BaseTestCase(treatment_variable, outcome_variable)
6767
test_list = []
6868
control_value = 0.5
6969
mean, oracle = mean_and_oracle

tests/testing_tests/test_causal_test_case.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from causal_testing.specification.causal_dag import CausalDAG
77
from causal_testing.testing.causal_test_case import CausalTestCase
88
from causal_testing.testing.causal_test_outcome import ExactValue
9-
from causal_testing.testing.base_causal_test import BaseCausalTest
9+
from causal_testing.testing.base_test_case import BaseTestCase
1010

1111
class TestCausalTestEngineObservational(unittest.TestCase):
1212
""" Test the CausalTestEngine workflow using observational data.
@@ -34,9 +34,9 @@ def setUp(self) -> None:
3434

3535
# 3. Create an intervention and causal test case
3636
self.expected_causal_effect = ExactValue(4)
37-
self.base_causal_test = BaseCausalTest(A, C)
37+
self.base_test_case = BaseTestCase(A, C)
3838
self.causal_test_case = CausalTestCase(
39-
base_causal_test=self.base_causal_test,
39+
base_test_case=self.base_test_case,
4040
expected_causal_effect=self.expected_causal_effect,
4141
control_value=0,
4242
treatment_value=1

0 commit comments

Comments
 (0)