Skip to content

Commit 78d1c3d

Browse files
committed
causal test case
1 parent 1cef1fb commit 78d1c3d

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

causal_testing/testing/causal_test_case.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
estimate_type: str = "ate",
3333
estimate_params: dict = None,
3434
effect_modifier_configuration: dict[Variable:Any] = None,
35+
estimator: type(Estimator) = None,
3536
):
3637
"""
3738
:param base_test_case: A BaseTestCase object consisting of a treatment variable, outcome variable and effect
@@ -40,6 +41,7 @@ def __init__(
4041
:param treatment_value: The treatment value for the treatment variable (after intervention).
4142
:param estimate_type: A string which denotes the type of estimate to return
4243
:param effect_modifier_configuration:
44+
:param estimator: An Estimator class instance
4345
"""
4446
self.base_test_case = base_test_case
4547
self.control_value = control_value
@@ -48,6 +50,7 @@ def __init__(
4850
self.treatment_variable = base_test_case.treatment_variable
4951
self.treatment_value = treatment_value
5052
self.estimate_type = estimate_type
53+
self.estimator = estimator
5154
if estimate_params is None:
5255
self.estimate_params = {}
5356
self.effect = base_test_case.effect
@@ -57,19 +60,18 @@ def __init__(
5760
else:
5861
self.effect_modifier_configuration = {}
5962

60-
def execute_test(self, estimator: type(Estimator)) -> CausalTestResult:
63+
def execute_test(self) -> CausalTestResult:
6164
"""Execute a causal test case and return the causal test result.
6265
63-
:param estimator: An Estimator class object
6466
:return causal_test_result: A CausalTestResult for the executed causal test case.
6567
"""
6668

67-
if not hasattr(estimator, f"estimate_{self.estimate_type}"):
68-
raise AttributeError(f"{estimator.__class__} has no {self.estimate_type} method.")
69-
estimate_effect = getattr(estimator, f"estimate_{self.estimate_type}")
69+
if not hasattr(self.estimator, f"estimate_{self.estimate_type}"):
70+
raise AttributeError(f"{self.estimator.__class__} has no {self.estimate_type} method.")
71+
estimate_effect = getattr(self.estimator, f"estimate_{self.estimate_type}")
7072
effect, confidence_intervals = estimate_effect(**self.estimate_params)
7173
return CausalTestResult(
72-
estimator=estimator,
74+
estimator=self.estimator,
7375
test_value=TestValue(self.estimate_type, effect),
7476
effect_modifier_configuration=self.effect_modifier_configuration,
7577
confidence_intervals=confidence_intervals,

0 commit comments

Comments
 (0)