@@ -25,45 +25,52 @@ def __init__(
         scenario: Scenario,
         intervention_constraints: set[z3.ExprRef],
         treatment_variables: set[Variable],
-        expected_causal_effect: dict[Variable: CausalTestOutcome],
+        expected_causal_effect: dict[Variable:CausalTestOutcome],
         effect_modifiers: set[Variable] = None,
-        estimate_type: str = "ate"
+        estimate_type: str = "ate",
     ):
         assert treatment_variables.issubset(scenario.variables.values()), (
             "Treatment variables must be a subset of variables."
             + f" Instead got:\n treatment_variables={treatment_variables}\n variables={scenario.variables}"
         )
 
-        assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"
+        assert (
+            len(expected_causal_effect) == 1
+        ), "We currently only support tests with one causal outcome"
 
         self.scenario = scenario
         self.intervention_constraints = intervention_constraints
         self.treatment_variables = treatment_variables
         self.expected_causal_effect = expected_causal_effect
-        self.estimate_type = estimate_type
+        self.estimate_type = estimate_type
 
         if effect_modifiers is not None:
             self.effect_modifiers = effect_modifiers
         else:
             self.effect_modifiers = {}
 
     def __str__(self):
-        outcome_string = " and ".join([f"the effect on {var} should be {str(effect)}" for var, effect in self.expected_causal_effect.items()])
-        return (
-            f"When we apply intervention {self.intervention_constraints}, {outcome_string}"
+        outcome_string = " and ".join(
+            [
+                f"the effect on {var} should be {str(effect)}"
+                for var, effect in self.expected_causal_effect.items()
+            ]
         )
+        return f"When we apply intervention {self.intervention_constraints}, {outcome_string}"
 
     def datapath(self):
         def sanitise(string):
             return "".join([x for x in string if x.isalnum()])
 
         return (
             sanitise("-".join([str(c) for c in self.intervention_constraints]))
-            + "_" + '-'.join([f"{v.name}_{e}" for v, e in self.expected_causal_effect.items()])
+            + "_"
+            + "-".join(
+                [f"{v.name}_{e}" for v, e in self.expected_causal_effect.items()]
+            )
             + ".csv"
         )
 
-
     def _generate_concrete_tests(
         self, sample_size: int, rct: bool = False, seed: int = 0
     ) -> tuple[list[CausalTestCase], pd.DataFrame]:
@@ -80,8 +87,9 @@ def _generate_concrete_tests(
 
         concrete_tests = []
         runs = []
-        run_columns = sorted([v.name for v in self.scenario.variables.values() if v.distribution])
-
+        run_columns = sorted(
+            [v.name for v in self.scenario.variables.values() if v.distribution]
+        )
 
         # Generate the Latin Hypercube samples and put into a dataframe
         # lhsmdu.setRandomSeed(seed+i)
@@ -103,7 +111,9 @@ def _generate_concrete_tests(
             for c in self.intervention_constraints:
                 optimizer.assert_and_track(c, str(c))
 
-            optimizer.add_soft([self.scenario.variables[v].z3 == row[v] for v in run_columns])
+            optimizer.add_soft(
+                [self.scenario.variables[v].z3 == row[v] for v in run_columns]
+            )
             if optimizer.check() == z3.unsat:
                 logger.warning(
                     "Satisfiability of test case was unsat.\n"
@@ -122,9 +132,9 @@ def _generate_concrete_tests(
                 expected_causal_effect=list(self.expected_causal_effect.values())[0],
                 outcome_variables=list(self.expected_causal_effect.keys()),
                 estimate_type=self.estimate_type,
-                effect_modifier_configuration={
+                effect_modifier_configuration={
                     v: v.cast(model[v.z3]) for v in self.effect_modifiers
-                }
+                },
             )
 
             for v in self.scenario.inputs():
@@ -160,9 +170,13 @@ def _generate_concrete_tests(
 
         return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
 
-
     def generate_concrete_tests(
-        self, sample_size: int, target_ks_score: float = None, rct: bool = False, seed: int = 0, hard_max: int = 1000
+        self,
+        sample_size: int,
+        target_ks_score: float = None,
+        rct: bool = False,
+        seed: int = 0,
+        hard_max: int = 1000,
     ) -> tuple[list[CausalTestCase], pd.DataFrame]:
         """Generates a list of `num` concrete test cases.
 
@@ -189,14 +203,22 @@ def generate_concrete_tests(
         ks_stats = []
 
        for i in range(hard_max):
-            concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed + i)
+            concrete_tests_, runs_ = self._generate_concrete_tests(
+                sample_size, rct, seed + i
+            )
             concrete_tests += concrete_tests_
             runs = pd.concat([runs, runs_])
-            assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
-
+            assert (
+                concrete_tests_ not in concrete_tests
+            ), "Duplicate entries unlikely unless something went wrong"
 
-            control_configs = pd.DataFrame([test.control_input_configuration for test in concrete_tests])
-            ks_stats = {var: stats.kstest(control_configs[var], var.distribution.cdf).statistic for var in control_configs.columns}
+            control_configs = pd.DataFrame(
+                [test.control_input_configuration for test in concrete_tests]
+            )
+            ks_stats = {
+                var: stats.kstest(control_configs[var], var.distribution.cdf).statistic
+                for var in control_configs.columns
+            }
             # Putting treatment and control values in messes it up because the two are not independent...
             # This is potentially problematic as constraints might mean we don't get good coverage if we use control values alone
             # We might then need to carefully craft our _control value_ generating distributions so that we can get good coverage
@@ -205,11 +227,29 @@
             # treatment_configs = pd.DataFrame([test.treatment_input_configuration for test in concrete_tests])
             # both_configs = pd.concat([control_configs, treatment_configs])
             # ks_stats = {var: stats.kstest(both_configs[var], var.distribution.cdf).statistic for var in both_configs.columns}
-            effect_modifier_configs = pd.DataFrame([test.effect_modifier_configuration for test in concrete_tests])
-            ks_stats.update({var: stats.kstest(effect_modifier_configs[var], var.distribution.cdf).statistic for var in effect_modifier_configs.columns})
-            if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
+            effect_modifier_configs = pd.DataFrame(
+                [test.effect_modifier_configuration for test in concrete_tests]
+            )
+            ks_stats.update(
+                {
+                    var: stats.kstest(
+                        effect_modifier_configs[var], var.distribution.cdf
+                    ).statistic
+                    for var in effect_modifier_configs.columns
+                }
+            )
+            if target_ks_score and all(
+                (stat <= target_ks_score for stat in ks_stats.values())
+            ):
                 break
 
-        if target_ks_score is not None and not all((stat <= target_ks_score for stat in ks_stats.values())):
-            logger.error("Hard max of %s reached but could not achieve target ks_score of %s. Got %s.", hard_max, target_ks_score, ks_stats)
+        if target_ks_score is not None and not all(
+            (stat <= target_ks_score for stat in ks_stats.values())
+        ):
+            logger.error(
+                "Hard max of %s reached but could not achieve target ks_score of %s. Got %s.",
+                hard_max,
+                target_ks_score,
+                ks_stats,
+            )
         return concrete_tests, runs
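For context, here is a minimal, self-contained sketch of the sampling loop that generate_concrete_tests implements after this change: keep drawing batches of candidate values, pool them, and stop once the Kolmogorov-Smirnov statistic against the variable's declared distribution drops to target_ks_score, or give up after hard_max attempts. The helper name sample_until_ks_target and the use of scipy.stats.uniform as a stand-in for a scenario variable's distribution are illustrative assumptions, not part of the framework.

# Illustrative sketch only: sample_until_ks_target is a hypothetical helper,
# and scipy.stats.uniform stands in for a scenario variable's distribution.
import numpy as np
from scipy import stats


def sample_until_ks_target(distribution, sample_size=20, target_ks_score=0.05, hard_max=1000, seed=0):
    rng = np.random.default_rng(seed)
    samples = np.array([])
    ks_stat = float("inf")
    for _ in range(hard_max):
        # Draw another batch and pool it with earlier batches, analogous to
        # concatenating runs_ onto runs in the diff above.
        samples = np.concatenate([samples, distribution.rvs(sample_size, random_state=rng)])
        # Compare the pooled sample against the target CDF, as the diff does with
        # stats.kstest(control_configs[var], var.distribution.cdf).statistic.
        ks_stat = stats.kstest(samples, distribution.cdf).statistic
        if ks_stat <= target_ks_score:
            break
    return samples, ks_stat


samples, ks_stat = sample_until_ks_target(stats.uniform(0, 10))
print(f"{len(samples)} samples drawn, final KS statistic {ks_stat:.3f}")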