4
4
import pandas as pd
5
5
import z3
6
6
from scipy import stats
7
+ import itertools
7
8
8
9
from causal_testing .specification .scenario import Scenario
9
10
from causal_testing .specification .variable import Variable
10
11
from causal_testing .testing .causal_test_case import CausalTestCase
11
12
from causal_testing .testing .causal_test_outcome import CausalTestOutcome
12
13
14
+ from enum import Enum
15
+
13
16
logger = logging .getLogger (__name__ )
14
17
15
18
@@ -24,21 +27,21 @@ def __init__(
24
27
self ,
25
28
scenario : Scenario ,
26
29
intervention_constraints : set [z3 .ExprRef ],
27
- treatment_variables : set [ Variable ] ,
30
+ treatment_variable : Variable ,
28
31
expected_causal_effect : dict [Variable :CausalTestOutcome ],
29
32
effect_modifiers : set [Variable ] = None ,
30
33
estimate_type : str = "ate" ,
31
34
):
32
- assert treatment_variables . issubset ( scenario .variables .values () ), (
35
+ assert treatment_variable in scenario .variables .values (), (
33
36
"Treatment variables must be a subset of variables."
34
- + f" Instead got:\n treatment_variables= { treatment_variables } \n variables={ scenario .variables } "
37
+ + f" Instead got:\n treatment_variable= { treatment_variable } \n variables={ scenario .variables } "
35
38
)
36
39
37
40
assert len (expected_causal_effect ) == 1 , "We currently only support tests with one causal outcome"
38
41
39
42
self .scenario = scenario
40
43
self .intervention_constraints = intervention_constraints
41
- self .treatment_variables = treatment_variables
44
+ self .treatment_variable = treatment_variable
42
45
self .expected_causal_effect = expected_causal_effect
43
46
self .estimate_type = estimate_type
44
47
@@ -113,9 +116,9 @@ def _generate_concrete_tests(
113
116
model = optimizer .model ()
114
117
115
118
concrete_test = CausalTestCase (
116
- control_input_configuration = {v : v .cast (model [v .z3 ]) for v in self .treatment_variables },
119
+ control_input_configuration = {v : v .cast (model [v .z3 ]) for v in [ self .treatment_variable ] },
117
120
treatment_input_configuration = {
118
- v : v .cast (model [self .scenario .treatment_variables [v .name ].z3 ]) for v in self .treatment_variables
121
+ v : v .cast (model [self .scenario .treatment_variables [v .name ].z3 ]) for v in [ self .treatment_variable ]
119
122
},
120
123
expected_causal_effect = list (self .expected_causal_effect .values ())[0 ],
121
124
outcome_variables = list (self .expected_causal_effect .keys ()),
@@ -208,7 +211,13 @@ def generate_concrete_tests(
208
211
for var in effect_modifier_configs .columns
209
212
}
210
213
)
211
- if target_ks_score and all ((stat <= target_ks_score for stat in ks_stats .values ())):
214
+ print ("=== test ===" )
215
+ control_values = [test .control_input_configuration [self .treatment_variable ] for test in concrete_tests ]
216
+ treatment_values = [test .treatment_input_configuration [self .treatment_variable ] for test in concrete_tests ]
217
+
218
+ if issubclass (self .treatment_variable .datatype , Enum ) and set (zip (control_values , treatment_values )).issubset (itertools .product (self .treatment_variable .datatype , self .treatment_variable .datatype )):
219
+ break
220
+ elif target_ks_score and all ((stat <= target_ks_score for stat in ks_stats .values ())):
212
221
break
213
222
214
223
if target_ks_score is not None and not all ((stat <= target_ks_score for stat in ks_stats .values ())):
0 commit comments