@@ -32,6 +32,7 @@ def __init__(
32
32
estimate_type : str = "ate" ,
33
33
estimate_params : dict = None ,
34
34
effect_modifier_configuration : dict [Variable :Any ] = None ,
35
+ estimator : type (Estimator ) = None ,
35
36
):
36
37
"""
37
38
:param base_test_case: A BaseTestCase object consisting of a treatment variable, outcome variable and effect
@@ -40,6 +41,7 @@ def __init__(
40
41
:param treatment_value: The treatment value for the treatment variable (after intervention).
41
42
:param estimate_type: A string which denotes the type of estimate to return
42
43
:param effect_modifier_configuration:
44
+ :param estimator: An Estimator class instance
43
45
"""
44
46
self .base_test_case = base_test_case
45
47
self .control_value = control_value
@@ -48,6 +50,7 @@ def __init__(
48
50
self .treatment_variable = base_test_case .treatment_variable
49
51
self .treatment_value = treatment_value
50
52
self .estimate_type = estimate_type
53
+ self .estimator = estimator
51
54
if estimate_params is None :
52
55
self .estimate_params = {}
53
56
self .effect = base_test_case .effect
@@ -57,19 +60,18 @@ def __init__(
57
60
else :
58
61
self .effect_modifier_configuration = {}
59
62
60
- def execute_test (self , estimator : type ( Estimator ) ) -> CausalTestResult :
63
+ def execute_test (self ) -> CausalTestResult :
61
64
"""Execute a causal test case and return the causal test result.
62
65
63
- :param estimator: An Estimator class object
64
66
:return causal_test_result: A CausalTestResult for the executed causal test case.
65
67
"""
66
68
67
- if not hasattr (estimator , f"estimate_{ self .estimate_type } " ):
68
- raise AttributeError (f"{ estimator .__class__ } has no { self .estimate_type } method." )
69
- estimate_effect = getattr (estimator , f"estimate_{ self .estimate_type } " )
69
+ if not hasattr (self . estimator , f"estimate_{ self .estimate_type } " ):
70
+ raise AttributeError (f"{ self . estimator .__class__ } has no { self .estimate_type } method." )
71
+ estimate_effect = getattr (self . estimator , f"estimate_{ self .estimate_type } " )
70
72
effect , confidence_intervals = estimate_effect (** self .estimate_params )
71
73
return CausalTestResult (
72
- estimator = estimator ,
74
+ estimator = self . estimator ,
73
75
test_value = TestValue (self .estimate_type , effect ),
74
76
effect_modifier_configuration = self .effect_modifier_configuration ,
75
77
confidence_intervals = confidence_intervals ,
0 commit comments