@@ -4,18 +4,22 @@
 import pandas as pd
 import z3
 from scipy import stats
+import itertools
 
 from causal_testing.specification.scenario import Scenario
 from causal_testing.specification.variable import Variable
 from causal_testing.testing.causal_test_case import CausalTestCase
 from causal_testing.testing.causal_test_outcome import CausalTestOutcome
+from causal_testing.testing.base_test_case import BaseTestCase
+
+from enum import Enum
 
 logger = logging.getLogger(__name__)
 
 
 class AbstractCausalTestCase:
     """
-    An abstract test case serves as a generator for concrete test cases. Instead of having concrete conctrol
+    An abstract test case serves as a generator for concrete test cases. Instead of having concrete control
     and treatment values, we instead just specify the intervention and the treatment variables. This then
     enables potentially infinite concrete test cases to be generated between different values of the treatment.
     """
@@ -24,23 +28,26 @@ def __init__(
         self,
         scenario: Scenario,
         intervention_constraints: set[z3.ExprRef],
-        treatment_variables: set[Variable],
+        treatment_variable: Variable,
         expected_causal_effect: dict[Variable:CausalTestOutcome],
         effect_modifiers: set[Variable] = None,
         estimate_type: str = "ate",
+        effect: str = "total",
     ):
-        assert treatment_variables.issubset(scenario.variables.values()), (
-            "Treatment variables must be a subset of variables."
-            + f" Instead got:\ntreatment_variables={treatment_variables}\nvariables={scenario.variables}"
-        )
+        if treatment_variable not in scenario.variables.values():
+            raise ValueError(
+                "Treatment variables must be a subset of variables."
+                + f" Instead got:\ntreatment_variables={treatment_variable}\nvariables={scenario.variables}"
+            )
 
         assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"
 
         self.scenario = scenario
         self.intervention_constraints = intervention_constraints
-        self.treatment_variables = treatment_variables
+        self.treatment_variable = treatment_variable
         self.expected_causal_effect = expected_causal_effect
         self.estimate_type = estimate_type
+        self.effect = effect
 
         if effect_modifiers is not None:
             self.effect_modifiers = effect_modifiers
@@ -100,7 +107,12 @@ def _generate_concrete_tests(
             for c in self.intervention_constraints:
                 optimizer.assert_and_track(c, str(c))
 
-            optimizer.add_soft([self.scenario.variables[v].z3 == row[v] for v in run_columns])
+            for v in run_columns:
+                optimizer.add_soft(
+                    self.scenario.variables[v].z3
+                    == self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])
+                )
+
             if optimizer.check() == z3.unsat:
                 logger.warning(
                     "Satisfiability of test case was unsat.\nConstraints\n%s\nUnsat core %s",
@@ -109,13 +121,19 @@ def _generate_concrete_tests(
                 )
             model = optimizer.model()
 
+            base_test_case = BaseTestCase(
+                treatment_variable=self.treatment_variable,
+                outcome_variable=list(self.expected_causal_effect.keys())[0],
+                effect=self.effect,
+            )
+
             concrete_test = CausalTestCase(
-                control_input_configuration={v: v.cast(model[v.z3]) for v in self.treatment_variables},
-                treatment_input_configuration={
-                    v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in self.treatment_variables
-                },
+                base_test_case=base_test_case,
+                control_value=self.treatment_variable.cast(model[self.treatment_variable.z3]),
+                treatment_value=self.treatment_variable.cast(
+                    model[self.scenario.treatment_variables[self.treatment_variable.name].z3]
+                ),
                 expected_causal_effect=list(self.expected_causal_effect.values())[0],
-                outcome_variables=list(self.expected_causal_effect.keys()),
                 estimate_type=self.estimate_type,
                 effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
             )
@@ -128,19 +146,20 @@ def _generate_concrete_tests(
                     + f"{constraints}\nUsing value {v.cast(model[v.z3])} instead in test\n{concrete_test}"
                 )
 
-            concrete_tests.append(concrete_test)
-            # Control run
-            control_run = {
-                v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
-            }
-            control_run["bin"] = index
-            runs.append(control_run)
-            # Treatment run
-            if rct:
-                treatment_run = control_run.copy()
-                treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
-                treatment_run["bin"] = index
-                runs.append(treatment_run)
+            if not any([vars(t) == vars(concrete_test) for t in concrete_tests]):
+                concrete_tests.append(concrete_test)
+                # Control run
+                control_run = {
+                    v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
+                }
+                control_run["bin"] = index
+                runs.append(control_run)
+                # Treatment run
+                if rct:
+                    treatment_run = control_run.copy()
+                    treatment_run.update({concrete_test.treatment_variable.name: concrete_test.treatment_value})
+                    treatment_run["bin"] = index
+                    runs.append(treatment_run)
 
         return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
 
@@ -176,13 +195,16 @@ def generate_concrete_tests(
         runs = pd.DataFrame()
         ks_stats = []
 
+        pre_break = False
         for i in range(hard_max):
             concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed + i)
-            concrete_tests += concrete_tests_
+            for t_ in concrete_tests_:
+                if not any([vars(t_) == vars(t) for t in concrete_tests]):
+                    concrete_tests.append(t_)
             runs = pd.concat([runs, runs_])
             assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
 
-            control_configs = pd.DataFrame([test.control_input_configuration for test in concrete_tests])
+            control_configs = pd.DataFrame([{test.treatment_variable: test.control_value} for test in concrete_tests])
             ks_stats = {
                 var: stats.kstest(control_configs[var], var.distribution.cdf).statistic
                 for var in control_configs.columns
@@ -205,14 +227,32 @@ def generate_concrete_tests(
                         for var in effect_modifier_configs.columns
                     }
                 )
-            if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
+            control_values = [test.control_value for test in concrete_tests]
+            treatment_values = [test.treatment_value for test in concrete_tests]
+
+            if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
+                set(zip(control_values, treatment_values))
+            ):
+                pre_break = True
+                break
+            if issubclass(self.treatment_variable.datatype, Enum) and set(
+                {
+                    (x, y)
+                    for x, y in itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)
+                    if x != y
+                }
+            ).issubset(zip(control_values, treatment_values)):
+                pre_break = True
+                break
+            elif target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
+                pre_break = True
                 break
 
-        if target_ks_score is not None and not all((stat <= target_ks_score for stat in ks_stats.values())):
+        if target_ks_score is not None and not pre_break:
             logger.error(
-                "Hard max of %s reached but could not achieve target ks_score of %s. Got %s.",
-                hard_max,
+                "Hard max reached but could not achieve target ks_score of %s. Got %s. Generated %s distinct tests",
                 target_ks_score,
                 ks_stats,
+                len(concrete_tests),
             )
         return concrete_tests, runs
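
A minimal usage sketch of the revised single-treatment API. The `Input`/`Output` variable subclasses, the `Positive` outcome class, the `setup_treatment_variables` helper, and the exact `generate_concrete_tests` keyword names are assumptions drawn from the wider framework, not confirmed by this change.

# Minimal sketch. Assumed names: Input/Output (Variable subclasses), Positive
# (a CausalTestOutcome), Scenario.setup_treatment_variables(), and the
# generate_concrete_tests keyword names.
import scipy.stats

from causal_testing.specification.scenario import Scenario
from causal_testing.specification.variable import Input, Output
from causal_testing.testing.causal_test_outcome import Positive

# One treatment variable (now singular) with a sampling distribution, one outcome.
dose = Input("dose", float, scipy.stats.uniform(0, 10))
response = Output("response", float)

scenario = Scenario(variables={dose, response})
scenario.setup_treatment_variables()  # assumed helper that creates the primed treatment copies

abstract_test = AbstractCausalTestCase(
    scenario=scenario,
    intervention_constraints={scenario.treatment_variables["dose"].z3 > dose.z3},
    treatment_variable=dose,  # previously treatment_variables={dose}
    expected_causal_effect={response: Positive()},  # exactly one outcome is enforced
    effect="total",  # new kwarg, forwarded to BaseTestCase
)

# Each concrete test now carries scalar control_value/treatment_value fields
# rather than the removed *_input_configuration dicts.
concrete_tests, runs = abstract_test.generate_concrete_tests(
    sample_size=5, target_ks_score=0.05, seed=0, hard_max=1000
)
for test in concrete_tests:
    print(test.treatment_variable.name, test.control_value, "->", test.treatment_value)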