1
1
"""This module contains the CausalTestSuite class, for details on using it:
2
2
https://causal-testing-framework.readthedocs.io/en/latest/test_suite.html"""
3
+ import logging
4
+
3
5
from collections import UserDict
4
6
from typing import Type , Iterable
5
7
from causal_testing .testing .base_test_case import BaseTestCase
6
8
from causal_testing .testing .causal_test_case import CausalTestCase
7
9
from causal_testing .testing .estimators import Estimator
10
+ from causal_testing .testing .causal_test_result import CausalTestResult , TestValue
11
+ from causal_testing .data_collection .data_collector import ObservationalDataCollector
12
+ from causal_testing .specification .causal_specification import CausalSpecification
13
+
14
+ logger = logging .getLogger (__name__ )
8
15
9
16
10
17
class CausalTestSuite (UserDict ):
@@ -20,11 +27,11 @@ class CausalTestSuite(UserDict):
20
27
"""
21
28
22
29
def add_test_object (
23
- self ,
24
- base_test_case : BaseTestCase ,
25
- causal_test_case_list : Iterable [CausalTestCase ],
26
- estimators_classes : Iterable [Type [Estimator ]],
27
- estimate_type : str = "ate" ,
30
+ self ,
31
+ base_test_case : BaseTestCase ,
32
+ causal_test_case_list : Iterable [CausalTestCase ],
33
+ estimators_classes : Iterable [Type [Estimator ]],
34
+ estimate_type : str = "ate" ,
28
35
):
29
36
"""
30
37
A setter object to allow for the easy construction of the dictionary test suite structure
@@ -37,3 +44,98 @@ def add_test_object(
37
44
"""
38
45
test_object = {"tests" : causal_test_case_list , "estimators" : estimators_classes , "estimate_type" : estimate_type }
39
46
self .data [base_test_case ] = test_object
47
+
48
def execute_test_suite(
    self, data_collector: ObservationalDataCollector, causal_specification: CausalSpecification
) -> dict:
    """Execute every causal test in the suite and return the results grouped by edge and estimator.

    For each base test case (edge) stored in the suite, the adjustment set is obtained from the
    causal DAG, positivity is checked against the collected data, and every test is run once per
    estimator class registered for that edge.

    :param data_collector: Data collector whose ``data`` dataframe is checked for emptiness, used for
                           the positivity check, and assigned to estimators that have no dataframe.
    :param causal_specification: Provides ``causal_dag`` (used for identification) and ``scenario``
                                 (used for the positivity check).
    :return: A dictionary keyed by base test case. Each value is a dictionary whose keys are the
             estimator class names and whose values are lists of CausalTestResult objects, one per
             test, in the order the tests were supplied.
    :raises ValueError: If the data collector holds no data, or if ``_check_positivity_violation``
                        reports a violation for any edge.
    """
    if data_collector.data.empty:
        raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.")
    data_collector.collect_data()

    test_suite_results = {}
    for edge, test_object in self.items():
        logger.info("treatment: %s", edge.treatment_variable)
        logger.info("outcome: %s", edge.outcome_variable)
        treatment_name = edge.treatment_variable.name
        outcome_name = edge.outcome_variable.name

        # Remove the treatment and outcome by *name*. Note that set("abc") would yield
        # {"a", "b", "c"}, so a set literal is required to subtract the whole name.
        minimal_adjustment_set = causal_specification.causal_dag.identification(edge) - {
            treatment_name,
            outcome_name,
        }

        variables_for_positivity = [*minimal_adjustment_set, treatment_name, outcome_name]
        if self._check_positivity_violation(
            variables_for_positivity, causal_specification.scenario, data_collector.data
        ):
            raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")

        # Materialise the tests once: they are declared as an Iterable and are traversed once per
        # estimator class, which would exhaust a one-shot iterator after the first estimator.
        tests = list(test_object["tests"])
        results = {}
        for estimator_class in test_object["estimators"]:
            causal_test_results = []
            for test in tests:
                estimator = estimator_class(
                    test.treatment_variable.name,
                    test.treatment_value,
                    test.control_value,
                    minimal_adjustment_set,
                    test.outcome_variable.name,
                )
                # Only supply the collected data when the estimator was not given its own.
                if estimator.df is None:
                    estimator.df = data_collector.data
                causal_test_results.append(self._return_causal_test_results(estimator, test))
            results[estimator_class.__name__] = causal_test_results
        test_suite_results[edge] = results
    return test_suite_results
96
+
97
def _return_causal_test_results(self, estimator, causal_test_case):
    """Run the estimation method selected by the test case and wrap the outcome in a CausalTestResult.

    The method invoked on the estimator is ``estimate_<estimate_type>``, where ``estimate_type`` is
    read from the test case; it is called with the test case's ``estimate_params`` as keyword
    arguments and is expected to return an ``(effect, confidence_intervals)`` pair.

    :param estimator: An Estimator class object
    :param causal_test_case: The concrete test case to be executed
    :return: a CausalTestResult object containing the effect and the confidence intervals
    :raises AttributeError: If the estimator has no method for the requested estimate type.
    """
    estimate_type = causal_test_case.estimate_type
    method_name = f"estimate_{estimate_type}"
    if not hasattr(estimator, method_name):
        raise AttributeError(f"{estimator.__class__} has no {estimate_type} method.")

    effect, confidence_intervals = getattr(estimator, method_name)(**causal_test_case.estimate_params)
    return CausalTestResult(
        estimator=estimator,
        test_value=TestValue(estimate_type, effect),
        effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
        confidence_intervals=confidence_intervals,
    )
116
+
117
def _check_positivity_violation(self, variables_list, scenario, data):
    """Check whether the dataframe has a positivity violation relative to the specified variables list.

    A positivity violation occurs when there is a stratum of the dataframe which does not have any data. Put simply,
    if we split the dataframe into covariate sub-groups, each sub-group must contain both a treated and untreated
    individual. If a positivity violation occurs, causal inference is still possible using a properly specified
    parametric estimator. Therefore, this method does not throw an exception upon violation; it logs a warning and
    reports the violation through its return value.

    As implemented, the check is column-based: a violation is reported when any non-hidden variable in
    ``variables_list`` has no corresponding column in ``data``.

    :param variables_list: The list of variable names for which positivity must be satisfied.
    :param scenario: Object whose ``hidden_variables()`` yields variables with a ``name``; those names are exempt
                     from the check.
    :param data: Dataframe whose ``columns`` are compared against the required variable names.
    :return: True if positivity is violated, False otherwise.
    """
    hidden_variable_names = {variable.name for variable in scenario.hidden_variables()}
    # Hidden variables are exempt from the check, so they must not be reported as missing either.
    missing_variables = set(variables_list) - hidden_variable_names - set(data.columns)
    if not missing_variables:
        return False

    logger.warning(
        "Positivity violation: missing data for variables %s.\n"
        "Causal inference is only valid if a well-specified parametric model is used.\n"
        "Alternatively, consider restricting analysis to executions without these variables.",
        missing_variables,
    )
    return True
0 commit comments