Skip to content

Commit 92ccc5d

Browse files
authored
Merge pull request #125 from CITCOM-project/enum_variables
Working support for ENUM variables
2 parents 024948f + fcb25c2 commit 92ccc5d

File tree

12 files changed: 389 additions and 123 deletions.

causal_testing/data_collection/data_collector.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import logging
22
from abc import ABC, abstractmethod
3+
from enum import Enum
34

45
import pandas as pd
56
import z3
@@ -140,4 +141,7 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
140141
for meta in self.scenario.metas():
141142
meta.populate(execution_data_df)
142143
scenario_execution_data_df = self.filter_valid_data(execution_data_df)
144+
for var_name, var in self.scenario.variables.items():
145+
if issubclass(var.datatype, Enum):
146+
scenario_execution_data_df[var_name] = [var.datatype(x) for x in scenario_execution_data_df[var_name]]
143147
return scenario_execution_data_df

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 58 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -4,12 +4,15 @@
44
import pandas as pd
55
import z3
66
from scipy import stats
7+
import itertools
78

89
from causal_testing.specification.scenario import Scenario
910
from causal_testing.specification.variable import Variable
1011
from causal_testing.testing.causal_test_case import CausalTestCase
1112
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
1213

14+
from enum import Enum
15+
1316
logger = logging.getLogger(__name__)
1417

1518

@@ -24,23 +27,25 @@ def __init__(
2427
self,
2528
scenario: Scenario,
2629
intervention_constraints: set[z3.ExprRef],
27-
treatment_variables: set[Variable],
30+
treatment_variable: Variable,
2831
expected_causal_effect: dict[Variable:CausalTestOutcome],
2932
effect_modifiers: set[Variable] = None,
3033
estimate_type: str = "ate",
34+
effect: str = "total",
3135
):
32-
assert treatment_variables.issubset(scenario.variables.values()), (
36+
assert treatment_variable in scenario.variables.values(), (
3337
"Treatment variables must be a subset of variables."
34-
+ f" Instead got:\ntreatment_variables={treatment_variables}\nvariables={scenario.variables}"
38+
+ f" Instead got:\ntreatment_variable={treatment_variable}\nvariables={scenario.variables}"
3539
)
3640

3741
assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"
3842

3943
self.scenario = scenario
4044
self.intervention_constraints = intervention_constraints
41-
self.treatment_variables = treatment_variables
45+
self.treatment_variable = treatment_variable
4246
self.expected_causal_effect = expected_causal_effect
4347
self.estimate_type = estimate_type
48+
self.effect = effect
4449

4550
if effect_modifiers is not None:
4651
self.effect_modifiers = effect_modifiers
@@ -100,7 +105,12 @@ def _generate_concrete_tests(
100105
for c in self.intervention_constraints:
101106
optimizer.assert_and_track(c, str(c))
102107

103-
optimizer.add_soft([self.scenario.variables[v].z3 == row[v] for v in run_columns])
108+
for v in run_columns:
109+
optimizer.add_soft(
110+
self.scenario.variables[v].z3
111+
== self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])
112+
)
113+
104114
if optimizer.check() == z3.unsat:
105115
logger.warning(
106116
"Satisfiability of test case was unsat.\n" "Constraints \n %s \n Unsat core %s",
@@ -110,14 +120,15 @@ def _generate_concrete_tests(
110120
model = optimizer.model()
111121

112122
concrete_test = CausalTestCase(
113-
control_input_configuration={v: v.cast(model[v.z3]) for v in self.treatment_variables},
123+
control_input_configuration={v: v.cast(model[v.z3]) for v in [self.treatment_variable]},
114124
treatment_input_configuration={
115-
v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in self.treatment_variables
125+
v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in [self.treatment_variable]
116126
},
117127
expected_causal_effect=list(self.expected_causal_effect.values())[0],
118128
outcome_variables=list(self.expected_causal_effect.keys()),
119129
estimate_type=self.estimate_type,
120130
effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
131+
effect=self.effect,
121132
)
122133

123134
for v in self.scenario.inputs():
@@ -128,19 +139,20 @@ def _generate_concrete_tests(
128139
+ f"{constraints}\nUsing value {v.cast(model[v.z3])} instead in test\n{concrete_test}"
129140
)
130141

131-
concrete_tests.append(concrete_test)
132-
# Control run
133-
control_run = {
134-
v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
135-
}
136-
control_run["bin"] = index
137-
runs.append(control_run)
138-
# Treatment run
139-
if rct:
140-
treatment_run = control_run.copy()
141-
treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
142-
treatment_run["bin"] = index
143-
runs.append(treatment_run)
142+
if not any([vars(t) == vars(concrete_test) for t in concrete_tests]):
143+
concrete_tests.append(concrete_test)
144+
# Control run
145+
control_run = {
146+
v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
147+
}
148+
control_run["bin"] = index
149+
runs.append(control_run)
150+
# Treatment run
151+
if rct:
152+
treatment_run = control_run.copy()
153+
treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
154+
treatment_run["bin"] = index
155+
runs.append(treatment_run)
144156

145157
return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
146158

@@ -176,9 +188,12 @@ def generate_concrete_tests(
176188
runs = pd.DataFrame()
177189
ks_stats = []
178190

191+
pre_break = False
179192
for i in range(hard_max):
180193
concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed + i)
181-
concrete_tests += concrete_tests_
194+
for t_ in concrete_tests_:
195+
if not any([vars(t_) == vars(t) for t in concrete_tests]):
196+
concrete_tests.append(t_)
182197
runs = pd.concat([runs, runs_])
183198
assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
184199

@@ -205,14 +220,32 @@ def generate_concrete_tests(
205220
for var in effect_modifier_configs.columns
206221
}
207222
)
208-
if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
223+
control_values = [test.control_input_configuration[self.treatment_variable] for test in concrete_tests]
224+
treatment_values = [test.treatment_input_configuration[self.treatment_variable] for test in concrete_tests]
225+
226+
if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
227+
set(zip(control_values, treatment_values))
228+
):
229+
pre_break = True
230+
break
231+
if issubclass(self.treatment_variable.datatype, Enum) and set(
232+
{
233+
(x, y)
234+
for x, y in itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)
235+
if x != y
236+
}
237+
).issubset(zip(control_values, treatment_values)):
238+
pre_break = True
239+
break
240+
elif target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
241+
pre_break = True
209242
break
210243

211-
if target_ks_score is not None and not all((stat <= target_ks_score for stat in ks_stats.values())):
244+
if target_ks_score is not None and not pre_break:
212245
logger.error(
213-
"Hard max of %s reached but could not achieve target ks_score of %s. Got %s.",
214-
hard_max,
246+
"Hard max reached but could not achieve target ks_score of %s. Got %s. Generated %s distinct tests",
215247
target_ks_score,
216248
ks_stats,
249+
len(concrete_tests),
217250
)
218251
return concrete_tests, runs

causal_testing/json_front/json_class.py

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ def set_variables(self, inputs: dict, outputs: dict, metas: dict):
7575
:param metas:
7676
"""
7777
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
78-
self.outputs = [Output(i["name"], i["type"]) for i in outputs]
78+
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
7979
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []
8080

8181
def setup(self):
@@ -89,10 +89,11 @@ def setup(self):
8989
self._populate_metas()
9090

9191
def _create_abstract_test_case(self, test, mutates, effects):
92+
assert len(test["mutations"]) == 1
9293
abstract_test = AbstractCausalTestCase(
9394
scenario=self.modelling_scenario,
9495
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
95-
treatment_variables={self.modelling_scenario.variables[v] for v in test["mutations"]},
96+
treatment_variable=next(self.modelling_scenario.variables[v] for v in test["mutations"]),
9697
expected_causal_effect={
9798
self.modelling_scenario.variables[variable]: effects[effect]
9899
for variable, effect in test["expectedEffect"].items()
@@ -101,6 +102,7 @@ def _create_abstract_test_case(self, test, mutates, effects):
101102
if "effect_modifiers" in test
102103
else {},
103104
estimate_type=test["estimate_type"],
105+
effect=test.get("effect", "total"),
104106
)
105107
return abstract_test
106108

@@ -121,10 +123,10 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
121123
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
122124
logger.info("Executing test: %s", test["name"])
123125
logger.info(abstract_test)
124-
logger.info([(v.name, v.distribution) for v in abstract_test.treatment_variables])
126+
logger.info([(v.name, v.distribution) for v in [abstract_test.treatment_variable]])
125127
logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
126128
failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
127-
logger.info("%s/%s failed", failures, len(concrete_tests))
129+
logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])
128130

129131
def _execute_tests(self, concrete_tests, estimators, test, f_flag):
130132
failures = 0
@@ -151,11 +153,12 @@ def _populate_metas(self):
151153
meta.populate(self.data)
152154

153155
for var in self.metas + self.outputs:
154-
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
155-
fitter.fit()
156-
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
157-
var.distribution = getattr(scipy.stats, dist)(**params)
158-
logger.info(var.name + f"{dist}({params})")
156+
if not var.distribution:
157+
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
158+
fitter.fit()
159+
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
160+
var.distribution = getattr(scipy.stats, dist)(**params)
161+
logger.info(var.name + f" {dist}({params})")
159162

160163
def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
161164
"""Executes a singular test case, prints the results and returns the test case result
@@ -178,19 +181,15 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
178181
if causal_test_result.ci_low() and causal_test_result.ci_high():
179182
result_string = f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
180183
else:
181-
result_string = causal_test_result.test_value.value
184+
result_string = f"{causal_test_result.test_value.value} no confidence intervals"
182185
if f_flag:
183186
assert test_passes, (
184187
f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
185188
f"got {result_string}"
186189
)
187190
if not test_passes:
188191
failed = True
189-
logger.warning(
190-
" FAILED- expected %s, got %s",
191-
causal_test_case.expected_causal_effect,
192-
causal_test_result.test_value.value,
193-
)
192+
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
194193
return failed
195194

196195
def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:

causal_testing/specification/causal_dag.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -255,7 +255,8 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
255255
gam.add_edges_from(edges_to_add)
256256

257257
min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
258-
# min_seps.remove(set(outcomes))
258+
if set(outcomes) in min_seps:
259+
min_seps.remove(set(outcomes))
259260
return min_seps
260261

261262
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
@@ -278,6 +279,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
278279
:param outcomes: A list of strings representing outcomes.
279280
:return: A list of strings representing the minimal adjustment set.
280281
"""
282+
281283
# 1. Construct the proper back-door graph's ancestor moral graph
282284
proper_backdoor_graph = self.get_proper_backdoor_graph(treatments, outcomes)
283285
ancestor_proper_backdoor_graph = proper_backdoor_graph.get_ancestor_graph(treatments, outcomes)
@@ -316,6 +318,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
316318
for adj in minimum_adjustment_sets
317319
if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, adj)
318320
]
321+
319322
return valid_minimum_adjustment_sets
320323

321324
def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str], adjustment_set: set[str]) -> bool:

causal_testing/specification/variable.py

Lines changed: 15 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import lhsmdu
77
from pandas import DataFrame
88
from scipy.stats._distn_infrastructure import rv_generic
9-
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String
9+
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String, DatatypeRef
1010

1111
# Declare type variable
1212
# Is there a better way? I'd really like to do Variable[T](ExprRef)
@@ -22,7 +22,7 @@ def z3_types(datatype):
2222
if datatype in types:
2323
return types[datatype]
2424
if issubclass(datatype, Enum):
25-
dtype, _ = EnumSort(datatype.__name__, [x.name for x in datatype])
25+
dtype, _ = EnumSort(datatype.__name__, [str(x.value) for x in datatype])
2626
return lambda x: Const(x, dtype)
2727
if hasattr(datatype, "to_z3"):
2828
return datatype.to_z3()
@@ -153,19 +153,27 @@ def cast(self, val: Any) -> T:
153153
:rtype: T
154154
"""
155155
assert val is not None, f"Invalid value None for variable {self}"
156+
if isinstance(val, self.datatype):
157+
return val
158+
if isinstance(val, BoolRef) and self.datatype == bool:
159+
return str(val) == "True"
156160
if isinstance(val, RatNumRef) and self.datatype == float:
157161
return float(val.numerator().as_long() / val.denominator().as_long())
158162
if hasattr(val, "is_string_value") and val.is_string_value() and self.datatype == str:
159163
return val.as_string()
160-
if (isinstance(val, float) or isinstance(val, int)) and (self.datatype == int or self.datatype == float):
164+
if (isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)) and (
165+
self.datatype == int or self.datatype == float or self.datatype == bool
166+
):
161167
return self.datatype(val)
168+
if issubclass(self.datatype, Enum) and isinstance(val, DatatypeRef):
169+
return self.datatype(str(val))
162170
return self.datatype(str(val))
163171

164172
def z3_val(self, z3_var, val: Any) -> T:
165173
native_val = self.cast(val)
166174
if isinstance(native_val, Enum):
167175
values = [z3_var.sort().constructor(c)() for c in range(z3_var.sort().num_constructors())]
168-
values = [v for v in values if str(v) == str(val)]
176+
values = [v for v in values if val.__class__(str(v)) == val]
169177
assert len(values) == 1, f"Expected {values} to be length 1"
170178
return values[0]
171179
return native_val
@@ -193,7 +201,6 @@ def typestring(self) -> str:
193201
"""
194202
return type(self).__name__
195203

196-
@abstractmethod
197204
def copy(self, name: str = None) -> Variable:
198205
"""Return a new instance of the Variable with the given name, or with
199206
the original name if no name is supplied.
@@ -203,26 +210,18 @@ def copy(self, name: str = None) -> Variable:
203210
:rtype: Variable
204211
205212
"""
206-
raise NotImplementedError("Method `copy` must be instantiated.")
213+
if name:
214+
return self.__class__(name, self.datatype, self.distribution)
215+
return self.__class__(self.name, self.datatype, self.distribution)
207216

208217

209218
class Input(Variable):
210219
"""An extension of the Variable class representing inputs."""
211220

212-
def copy(self, name=None) -> Input:
213-
if name:
214-
return Input(name, self.datatype, self.distribution)
215-
return Input(self.name, self.datatype, self.distribution)
216-
217221

218222
class Output(Variable):
219223
"""An extension of the Variable class representing outputs."""
220224

221-
def copy(self, name=None) -> Output:
222-
if name:
223-
return Output(name, self.datatype, self.distribution)
224-
return Output(self.name, self.datatype, self.distribution)
225-
226225

227226
class Meta(Variable):
228227
"""An extension of the Variable class representing metavariables. These are variables which are relevant to the
@@ -242,8 +241,3 @@ class Meta(Variable):
242241
def __init__(self, name: str, datatype: T, populate: Callable[[DataFrame], DataFrame]):
243242
super().__init__(name, datatype)
244243
self.populate = populate
245-
246-
def copy(self, name=None) -> Meta:
247-
if name:
248-
return Meta(name, self.datatype, self.distribution)
249-
return Meta(self.name, self.datatype, self.distribution)

0 commit comments

Comments
 (0)