Skip to content

Commit 479119d

Browse files
committed
Stop generation if all combinations are in
1 parent 647927c commit 479119d

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
import pandas as pd
55
import z3
66
from scipy import stats
7+
import itertools
78

89
from causal_testing.specification.scenario import Scenario
910
from causal_testing.specification.variable import Variable
1011
from causal_testing.testing.causal_test_case import CausalTestCase
1112
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
1213

14+
from enum import Enum
15+
1316
logger = logging.getLogger(__name__)
1417

1518

@@ -24,21 +27,21 @@ def __init__(
2427
self,
2528
scenario: Scenario,
2629
intervention_constraints: set[z3.ExprRef],
27-
treatment_variables: set[Variable],
30+
treatment_variable: Variable,
2831
expected_causal_effect: dict[Variable:CausalTestOutcome],
2932
effect_modifiers: set[Variable] = None,
3033
estimate_type: str = "ate",
3134
):
32-
assert treatment_variables.issubset(scenario.variables.values()), (
35+
assert treatment_variable in scenario.variables.values(), (
3336
"Treatment variables must be a subset of variables."
34-
+ f" Instead got:\ntreatment_variables={treatment_variables}\nvariables={scenario.variables}"
37+
+ f" Instead got:\ntreatment_variable={treatment_variable}\nvariables={scenario.variables}"
3538
)
3639

3740
assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"
3841

3942
self.scenario = scenario
4043
self.intervention_constraints = intervention_constraints
41-
self.treatment_variables = treatment_variables
44+
self.treatment_variable = treatment_variable
4245
self.expected_causal_effect = expected_causal_effect
4346
self.estimate_type = estimate_type
4447

@@ -113,9 +116,9 @@ def _generate_concrete_tests(
113116
model = optimizer.model()
114117

115118
concrete_test = CausalTestCase(
116-
control_input_configuration={v: v.cast(model[v.z3]) for v in self.treatment_variables},
119+
control_input_configuration={v: v.cast(model[v.z3]) for v in [self.treatment_variable]},
117120
treatment_input_configuration={
118-
v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in self.treatment_variables
121+
v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in [self.treatment_variable]
119122
},
120123
expected_causal_effect=list(self.expected_causal_effect.values())[0],
121124
outcome_variables=list(self.expected_causal_effect.keys()),
@@ -208,7 +211,13 @@ def generate_concrete_tests(
208211
for var in effect_modifier_configs.columns
209212
}
210213
)
211-
if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
214+
print("=== test ===")
215+
control_values = [test.control_input_configuration[self.treatment_variable] for test in concrete_tests]
216+
treatment_values = [test.treatment_input_configuration[self.treatment_variable] for test in concrete_tests]
217+
218+
if issubclass(self.treatment_variable.datatype, Enum) and set(zip(control_values, treatment_values)).issubset(itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)):
219+
break
220+
elif target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
212221
break
213222

214223
if target_ks_score is not None and not all((stat <= target_ks_score for stat in ks_stats.values())):

causal_testing/json_front/json_class.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,12 @@ def setup(self):
8989
self._populate_metas()
9090

9191
def _create_abstract_test_case(self, test, mutates, effects):
92+
assert len(test["mutations"]) == 1
9293
abstract_test = AbstractCausalTestCase(
9394
scenario=self.modelling_scenario,
9495
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
95-
treatment_variables={self.modelling_scenario.variables[v] for v in test["mutations"]},
96+
# TODO: Could change JSON to be treatment_var and mutation to it rather than a dict of mutations
97+
treatment_variable=next(self.modelling_scenario.variables[v] for v in test["mutations"]),
9698
expected_causal_effect={
9799
self.modelling_scenario.variables[variable]: effects[effect]
98100
for variable, effect in test["expectedEffect"].items()
@@ -121,7 +123,7 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
121123
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
122124
logger.info("Executing test: %s", test["name"])
123125
logger.info(abstract_test)
124-
logger.info([(v.name, v.distribution) for v in abstract_test.treatment_variables])
126+
logger.info([(v.name, v.distribution) for v in [abstract_test.treatment_variable]])
125127
logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
126128
failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
127129
logger.info("%s/%s failed", failures, len(concrete_tests))

causal_testing/testing/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def estimate_ate(self) -> float:
203203
estimate = treatment_outcome - control_outcome
204204

205205
logger.info(
206-
f"Changing {self.treatment} from {self.control_values} to {self.treatment_values} gives an estimated ATE of {ci_low} < {estimate} < {ci_high}"
206+
f"Changing {self.treatment[0]} from {self.control_values} to {self.treatment_values} gives an estimated ATE of {ci_low} < {estimate} < {ci_high}"
207207
)
208208
assert ci_low < estimate < ci_high, f"Expecting {ci_low} < {estimate} < {ci_high}"
209209

0 commit comments

Comments
 (0)