4
4
import pandas as pd
5
5
import z3
6
6
from scipy import stats
7
+ import itertools
7
8
8
9
from causal_testing .specification .scenario import Scenario
9
10
from causal_testing .specification .variable import Variable
10
11
from causal_testing .testing .causal_test_case import CausalTestCase
11
12
from causal_testing .testing .causal_test_outcome import CausalTestOutcome
12
13
14
+ from enum import Enum
15
+
13
16
logger = logging .getLogger (__name__ )
14
17
15
18
@@ -24,23 +27,25 @@ def __init__(
24
27
self ,
25
28
scenario : Scenario ,
26
29
intervention_constraints : set [z3 .ExprRef ],
27
- treatment_variables : set [ Variable ] ,
30
+ treatment_variable : Variable ,
28
31
expected_causal_effect : dict [Variable :CausalTestOutcome ],
29
32
effect_modifiers : set [Variable ] = None ,
30
33
estimate_type : str = "ate" ,
34
+ effect : str = "total" ,
31
35
):
32
- assert treatment_variables . issubset ( scenario .variables .values () ), (
36
+ assert treatment_variable in scenario .variables .values (), (
33
37
"Treatment variables must be a subset of variables."
34
- + f" Instead got:\n treatment_variables= { treatment_variables } \n variables={ scenario .variables } "
38
+ + f" Instead got:\n treatment_variable= { treatment_variable } \n variables={ scenario .variables } "
35
39
)
36
40
37
41
assert len (expected_causal_effect ) == 1 , "We currently only support tests with one causal outcome"
38
42
39
43
self .scenario = scenario
40
44
self .intervention_constraints = intervention_constraints
41
- self .treatment_variables = treatment_variables
45
+ self .treatment_variable = treatment_variable
42
46
self .expected_causal_effect = expected_causal_effect
43
47
self .estimate_type = estimate_type
48
+ self .effect = effect
44
49
45
50
if effect_modifiers is not None :
46
51
self .effect_modifiers = effect_modifiers
@@ -100,7 +105,12 @@ def _generate_concrete_tests(
100
105
for c in self .intervention_constraints :
101
106
optimizer .assert_and_track (c , str (c ))
102
107
103
- optimizer .add_soft ([self .scenario .variables [v ].z3 == row [v ] for v in run_columns ])
108
+ for v in run_columns :
109
+ optimizer .add_soft (
110
+ self .scenario .variables [v ].z3
111
+ == self .scenario .variables [v ].z3_val (self .scenario .variables [v ].z3 , row [v ])
112
+ )
113
+
104
114
if optimizer .check () == z3 .unsat :
105
115
logger .warning (
106
116
"Satisfiability of test case was unsat.\n " "Constraints \n %s \n Unsat core %s" ,
@@ -110,14 +120,15 @@ def _generate_concrete_tests(
110
120
model = optimizer .model ()
111
121
112
122
concrete_test = CausalTestCase (
113
- control_input_configuration = {v : v .cast (model [v .z3 ]) for v in self .treatment_variables },
123
+ control_input_configuration = {v : v .cast (model [v .z3 ]) for v in [ self .treatment_variable ] },
114
124
treatment_input_configuration = {
115
- v : v .cast (model [self .scenario .treatment_variables [v .name ].z3 ]) for v in self .treatment_variables
125
+ v : v .cast (model [self .scenario .treatment_variables [v .name ].z3 ]) for v in [ self .treatment_variable ]
116
126
},
117
127
expected_causal_effect = list (self .expected_causal_effect .values ())[0 ],
118
128
outcome_variables = list (self .expected_causal_effect .keys ()),
119
129
estimate_type = self .estimate_type ,
120
130
effect_modifier_configuration = {v : v .cast (model [v .z3 ]) for v in self .effect_modifiers },
131
+ effect = self .effect ,
121
132
)
122
133
123
134
for v in self .scenario .inputs ():
@@ -128,19 +139,20 @@ def _generate_concrete_tests(
128
139
+ f"{ constraints } \n Using value { v .cast (model [v .z3 ])} instead in test\n { concrete_test } "
129
140
)
130
141
131
- concrete_tests .append (concrete_test )
132
- # Control run
133
- control_run = {
134
- v .name : v .cast (model [v .z3 ]) for v in self .scenario .variables .values () if v .name in run_columns
135
- }
136
- control_run ["bin" ] = index
137
- runs .append (control_run )
138
- # Treatment run
139
- if rct :
140
- treatment_run = control_run .copy ()
141
- treatment_run .update ({k .name : v for k , v in concrete_test .treatment_input_configuration .items ()})
142
- treatment_run ["bin" ] = index
143
- runs .append (treatment_run )
142
+ if not any ([vars (t ) == vars (concrete_test ) for t in concrete_tests ]):
143
+ concrete_tests .append (concrete_test )
144
+ # Control run
145
+ control_run = {
146
+ v .name : v .cast (model [v .z3 ]) for v in self .scenario .variables .values () if v .name in run_columns
147
+ }
148
+ control_run ["bin" ] = index
149
+ runs .append (control_run )
150
+ # Treatment run
151
+ if rct :
152
+ treatment_run = control_run .copy ()
153
+ treatment_run .update ({k .name : v for k , v in concrete_test .treatment_input_configuration .items ()})
154
+ treatment_run ["bin" ] = index
155
+ runs .append (treatment_run )
144
156
145
157
return concrete_tests , pd .DataFrame (runs , columns = run_columns + ["bin" ])
146
158
@@ -176,9 +188,12 @@ def generate_concrete_tests(
176
188
runs = pd .DataFrame ()
177
189
ks_stats = []
178
190
191
+ pre_break = False
179
192
for i in range (hard_max ):
180
193
concrete_tests_ , runs_ = self ._generate_concrete_tests (sample_size , rct , seed + i )
181
- concrete_tests += concrete_tests_
194
+ for t_ in concrete_tests_ :
195
+ if not any ([vars (t_ ) == vars (t ) for t in concrete_tests ]):
196
+ concrete_tests .append (t_ )
182
197
runs = pd .concat ([runs , runs_ ])
183
198
assert concrete_tests_ not in concrete_tests , "Duplicate entries unlikely unless something went wrong"
184
199
@@ -205,14 +220,32 @@ def generate_concrete_tests(
205
220
for var in effect_modifier_configs .columns
206
221
}
207
222
)
208
- if target_ks_score and all ((stat <= target_ks_score for stat in ks_stats .values ())):
223
+ control_values = [test .control_input_configuration [self .treatment_variable ] for test in concrete_tests ]
224
+ treatment_values = [test .treatment_input_configuration [self .treatment_variable ] for test in concrete_tests ]
225
+
226
+ if self .treatment_variable .datatype is bool and set ([(True , False ), (False , True )]).issubset (
227
+ set (zip (control_values , treatment_values ))
228
+ ):
229
+ pre_break = True
230
+ break
231
+ if issubclass (self .treatment_variable .datatype , Enum ) and set (
232
+ {
233
+ (x , y )
234
+ for x , y in itertools .product (self .treatment_variable .datatype , self .treatment_variable .datatype )
235
+ if x != y
236
+ }
237
+ ).issubset (zip (control_values , treatment_values )):
238
+ pre_break = True
239
+ break
240
+ elif target_ks_score and all ((stat <= target_ks_score for stat in ks_stats .values ())):
241
+ pre_break = True
209
242
break
210
243
211
- if target_ks_score is not None and not all (( stat <= target_ks_score for stat in ks_stats . values ())) :
244
+ if target_ks_score is not None and not pre_break :
212
245
logger .error (
213
- "Hard max of %s reached but could not achieve target ks_score of %s. Got %s." ,
214
- hard_max ,
246
+ "Hard max reached but could not achieve target ks_score of %s. Got %s. Generated %s distinct tests" ,
215
247
target_ks_score ,
216
248
ks_stats ,
249
+ len (concrete_tests ),
217
250
)
218
251
return concrete_tests , runs
0 commit comments