4
4
import pandas as pd
5
5
import z3
6
6
from scipy import stats
7
+ import itertools
7
8
8
9
from causal_testing .specification .scenario import Scenario
9
10
from causal_testing .specification .variable import Variable
10
11
from causal_testing .testing .causal_test_case import CausalTestCase
11
12
from causal_testing .testing .causal_test_outcome import CausalTestOutcome
12
13
from causal_testing .testing .base_test_case import BaseTestCase
13
14
15
+ from enum import Enum
16
+
14
17
logger = logging .getLogger (__name__ )
15
18
16
19
@@ -25,23 +28,26 @@ def __init__(
25
28
self ,
26
29
scenario : Scenario ,
27
30
intervention_constraints : set [z3 .ExprRef ],
28
- treatment_variables : set [ Variable ] ,
31
+ treatment_variable : Variable ,
29
32
expected_causal_effect : dict [Variable :CausalTestOutcome ],
30
33
effect_modifiers : set [Variable ] = None ,
31
34
estimate_type : str = "ate" ,
35
+ effect : str = "total" ,
32
36
):
33
- assert {treatment_variables }.issubset (scenario .variables .values ()), (
34
- "Treatment variables must be a subset of variables."
35
- + f" Instead got:\n treatment_variables={ treatment_variables } \n variables={ scenario .variables } "
36
- )
37
+ if treatment_variable not in scenario .variables .values ():
38
+ raise ValueError (
39
+ "Treatment variables must be a subset of variables."
40
+ + f" Instead got:\n treatment_variables={ treatment_variable } \n variables={ scenario .variables } "
41
+ )
37
42
38
43
assert len (expected_causal_effect ) == 1 , "We currently only support tests with one causal outcome"
39
44
40
45
self .scenario = scenario
41
46
self .intervention_constraints = intervention_constraints
42
- self .treatment_variables = treatment_variables
47
+ self .treatment_variable = treatment_variable
43
48
self .expected_causal_effect = expected_causal_effect
44
49
self .estimate_type = estimate_type
50
+ self .effect = effect
45
51
46
52
if effect_modifiers is not None :
47
53
self .effect_modifiers = effect_modifiers
@@ -101,7 +107,12 @@ def _generate_concrete_tests(
101
107
for c in self .intervention_constraints :
102
108
optimizer .assert_and_track (c , str (c ))
103
109
104
- optimizer .add_soft ([self .scenario .variables [v ].z3 == row [v ] for v in run_columns ])
110
+ for v in run_columns :
111
+ optimizer .add_soft (
112
+ self .scenario .variables [v ].z3
113
+ == self .scenario .variables [v ].z3_val (self .scenario .variables [v ].z3 , row [v ])
114
+ )
115
+
105
116
if optimizer .check () == z3 .unsat :
106
117
logger .warning (
107
118
"Satisfiability of test case was unsat.\n " "Constraints \n %s \n Unsat core %s" ,
@@ -110,13 +121,17 @@ def _generate_concrete_tests(
110
121
)
111
122
model = optimizer .model ()
112
123
113
- base_test_case = BaseTestCase (self .treatment_variables , list (self .expected_causal_effect .keys ())[0 ])
124
+ base_test_case = BaseTestCase (
125
+ treatment_variable = self .treatment_variable ,
126
+ outcome_variable = list (self .expected_causal_effect .keys ())[0 ],
127
+ effect = self .effect ,
128
+ )
114
129
115
130
concrete_test = CausalTestCase (
116
131
base_test_case = base_test_case ,
117
- control_value = self .treatment_variables .cast (model [self .treatment_variables .z3 ]),
118
- treatment_value = self .treatment_variables .cast (
119
- model [self .scenario .treatment_variables [self .treatment_variables .name ].z3 ]
132
+ control_value = self .treatment_variable .cast (model [self .treatment_variable .z3 ]),
133
+ treatment_value = self .treatment_variable .cast (
134
+ model [self .scenario .treatment_variables [self .treatment_variable .name ].z3 ]
120
135
),
121
136
expected_causal_effect = list (self .expected_causal_effect .values ())[0 ],
122
137
estimate_type = self .estimate_type ,
@@ -131,20 +146,22 @@ def _generate_concrete_tests(
131
146
+ f"{ constraints } \n Using value { v .cast (model [v .z3 ])} instead in test\n { concrete_test } "
132
147
)
133
148
134
- concrete_tests .append (concrete_test )
135
- # Control run
136
- control_run = {
137
- v .name : v .cast (model [v .z3 ]) for v in self .scenario .variables .values () if v .name in run_columns
138
- }
139
- control_run ["bin" ] = index
140
- runs .append (control_run )
141
- # Treatment run
142
- if rct :
143
- treatment_run = control_run .copy ()
144
- treatment_run .update ({concrete_test .treatment_variable .name : concrete_test .treatment_value })
145
- # treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
146
- treatment_run ["bin" ] = index
147
- runs .append (treatment_run )
149
+
150
+ if not any ([vars (t ) == vars (concrete_test ) for t in concrete_tests ]):
151
+ concrete_tests .append (concrete_test )
152
+ # Control run
153
+ control_run = {
154
+ v .name : v .cast (model [v .z3 ]) for v in self .scenario .variables .values () if v .name in run_columns
155
+ }
156
+ control_run ["bin" ] = index
157
+ runs .append (control_run )
158
+ # Treatment run
159
+ if rct :
160
+ treatment_run = control_run .copy ()
161
+ treatment_run .update ({concrete_test .treatment_variable .name : concrete_test .treatment_value })
162
+ treatment_run ["bin" ] = index
163
+ runs .append (treatment_run )
164
+
148
165
149
166
return concrete_tests , pd .DataFrame (runs , columns = run_columns + ["bin" ])
150
167
@@ -180,9 +197,12 @@ def generate_concrete_tests(
180
197
runs = pd .DataFrame ()
181
198
ks_stats = []
182
199
200
+ pre_break = False
183
201
for i in range (hard_max ):
184
202
concrete_tests_ , runs_ = self ._generate_concrete_tests (sample_size , rct , seed + i )
185
- concrete_tests += concrete_tests_
203
+ for t_ in concrete_tests_ :
204
+ if not any ([vars (t_ ) == vars (t ) for t in concrete_tests ]):
205
+ concrete_tests .append (t_ )
186
206
runs = pd .concat ([runs , runs_ ])
187
207
assert concrete_tests_ not in concrete_tests , "Duplicate entries unlikely unless something went wrong"
188
208
@@ -209,14 +229,32 @@ def generate_concrete_tests(
209
229
for var in effect_modifier_configs .columns
210
230
}
211
231
)
212
- if target_ks_score and all ((stat <= target_ks_score for stat in ks_stats .values ())):
232
+ control_values = [test .control_input_configuration [self .treatment_variable ] for test in concrete_tests ]
233
+ treatment_values = [test .treatment_input_configuration [self .treatment_variable ] for test in concrete_tests ]
234
+
235
+ if self .treatment_variable .datatype is bool and set ([(True , False ), (False , True )]).issubset (
236
+ set (zip (control_values , treatment_values ))
237
+ ):
238
+ pre_break = True
239
+ break
240
+ if issubclass (self .treatment_variable .datatype , Enum ) and set (
241
+ {
242
+ (x , y )
243
+ for x , y in itertools .product (self .treatment_variable .datatype , self .treatment_variable .datatype )
244
+ if x != y
245
+ }
246
+ ).issubset (zip (control_values , treatment_values )):
247
+ pre_break = True
248
+ break
249
+ elif target_ks_score and all ((stat <= target_ks_score for stat in ks_stats .values ())):
250
+ pre_break = True
213
251
break
214
252
215
- if target_ks_score is not None and not all (( stat <= target_ks_score for stat in ks_stats . values ())) :
253
+ if target_ks_score is not None and not pre_break :
216
254
logger .error (
217
- "Hard max of %s reached but could not achieve target ks_score of %s. Got %s." ,
218
- hard_max ,
255
+ "Hard max reached but could not achieve target ks_score of %s. Got %s. Generated %s distinct tests" ,
219
256
target_ks_score ,
220
257
ks_stats ,
258
+ len (concrete_tests ),
221
259
)
222
260
return concrete_tests , runs
0 commit comments