Skip to content

Commit 6d0e39b

Browse files
Merge branch 'main' into base-causal-test-case
# Conflicts: # causal_testing/generation/abstract_causal_test_case.py # causal_testing/json_front/json_class.py # tests/generation_tests/test_abstract_test_case.py
2 parents 7e70730 + 92ccc5d commit 6d0e39b

File tree

15 files changed

+409
-123
lines changed

15 files changed

+409
-123
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Here are some explanations for the causal inference terminology used above.
3131

3232
## Installation
3333

34-
To use the causal testing framework, clone the repository, `cd` into the root directory, and run `pip install -e .`. More detailed installation instructions can be found in the [online documentation](https://causal-testing-framework.readthedocs.io/en/latest/installation.html).
34+
See the [readthedocs site](https://causal-testing-framework.readthedocs.io/en/latest/installation.html) for installation instructions.
3535

3636
## Usage
3737

causal_testing/data_collection/data_collector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from abc import ABC, abstractmethod
3+
from enum import Enum
34

45
import pandas as pd
56
import z3
@@ -140,4 +141,7 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
140141
for meta in self.scenario.metas():
141142
meta.populate(execution_data_df)
142143
scenario_execution_data_df = self.filter_valid_data(execution_data_df)
144+
for var_name, var in self.scenario.variables.items():
145+
if issubclass(var.datatype, Enum):
146+
scenario_execution_data_df[var_name] = [var.datatype(x) for x in scenario_execution_data_df[var_name]]
143147
return scenario_execution_data_df

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
import pandas as pd
55
import z3
66
from scipy import stats
7+
import itertools
78

89
from causal_testing.specification.scenario import Scenario
910
from causal_testing.specification.variable import Variable
1011
from causal_testing.testing.causal_test_case import CausalTestCase
1112
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
1213
from causal_testing.testing.base_test_case import BaseTestCase
1314

15+
from enum import Enum
16+
1417
logger = logging.getLogger(__name__)
1518

1619

@@ -25,24 +28,26 @@ def __init__(
2528
self,
2629
scenario: Scenario,
2730
intervention_constraints: set[z3.ExprRef],
28-
treatment_variables: set[Variable],
31+
treatment_variable: Variable,
2932
expected_causal_effect: dict[Variable:CausalTestOutcome],
3033
effect_modifiers: set[Variable] = None,
3134
estimate_type: str = "ate",
35+
effect: str = "total",
3236
):
33-
if treatment_variables not in scenario.variables.values():
37+
if treatment_variable not in scenario.variables.values():
3438
raise ValueError(
3539
"Treatment variables must be a subset of variables."
36-
+ f" Instead got:\ntreatment_variables={treatment_variables}\nvariables={scenario.variables}"
40+
+ f" Instead got:\ntreatment_variables={treatment_variable}\nvariables={scenario.variables}"
3741
)
3842

3943
assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"
4044

4145
self.scenario = scenario
4246
self.intervention_constraints = intervention_constraints
43-
self.treatment_variables = treatment_variables
47+
self.treatment_variable = treatment_variable
4448
self.expected_causal_effect = expected_causal_effect
4549
self.estimate_type = estimate_type
50+
self.effect = effect
4651

4752
if effect_modifiers is not None:
4853
self.effect_modifiers = effect_modifiers
@@ -102,7 +107,12 @@ def _generate_concrete_tests(
102107
for c in self.intervention_constraints:
103108
optimizer.assert_and_track(c, str(c))
104109

105-
optimizer.add_soft([self.scenario.variables[v].z3 == row[v] for v in run_columns])
110+
for v in run_columns:
111+
optimizer.add_soft(
112+
self.scenario.variables[v].z3
113+
== self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])
114+
)
115+
106116
if optimizer.check() == z3.unsat:
107117
logger.warning(
108118
"Satisfiability of test case was unsat.\n" "Constraints \n %s \n Unsat core %s",
@@ -122,6 +132,7 @@ def _generate_concrete_tests(
122132
expected_causal_effect=list(self.expected_causal_effect.values())[0],
123133
estimate_type=self.estimate_type,
124134
effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
135+
effect=self.effect,
125136
)
126137

127138
for v in self.scenario.inputs():
@@ -132,19 +143,20 @@ def _generate_concrete_tests(
132143
+ f"{constraints}\nUsing value {v.cast(model[v.z3])} instead in test\n{concrete_test}"
133144
)
134145

135-
concrete_tests.append(concrete_test)
136-
# Control run
137-
control_run = {
138-
v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
139-
}
140-
control_run["bin"] = index
141-
runs.append(control_run)
142-
# Treatment run
143-
if rct:
144-
treatment_run = control_run.copy()
145-
treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
146-
treatment_run["bin"] = index
147-
runs.append(treatment_run)
146+
if not any([vars(t) == vars(concrete_test) for t in concrete_tests]):
147+
concrete_tests.append(concrete_test)
148+
# Control run
149+
control_run = {
150+
v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
151+
}
152+
control_run["bin"] = index
153+
runs.append(control_run)
154+
# Treatment run
155+
if rct:
156+
treatment_run = control_run.copy()
157+
treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
158+
treatment_run["bin"] = index
159+
runs.append(treatment_run)
148160

149161
return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
150162

@@ -180,9 +192,12 @@ def generate_concrete_tests(
180192
runs = pd.DataFrame()
181193
ks_stats = []
182194

195+
pre_break = False
183196
for i in range(hard_max):
184197
concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed + i)
185-
concrete_tests += concrete_tests_
198+
for t_ in concrete_tests_:
199+
if not any([vars(t_) == vars(t) for t in concrete_tests]):
200+
concrete_tests.append(t_)
186201
runs = pd.concat([runs, runs_])
187202
assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
188203

@@ -209,14 +224,32 @@ def generate_concrete_tests(
209224
for var in effect_modifier_configs.columns
210225
}
211226
)
212-
if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
227+
control_values = [test.control_input_configuration[self.treatment_variable] for test in concrete_tests]
228+
treatment_values = [test.treatment_input_configuration[self.treatment_variable] for test in concrete_tests]
229+
230+
if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
231+
set(zip(control_values, treatment_values))
232+
):
233+
pre_break = True
234+
break
235+
if issubclass(self.treatment_variable.datatype, Enum) and set(
236+
{
237+
(x, y)
238+
for x, y in itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)
239+
if x != y
240+
}
241+
).issubset(zip(control_values, treatment_values)):
242+
pre_break = True
243+
break
244+
elif target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
245+
pre_break = True
213246
break
214247

215-
if target_ks_score is not None and not all((stat <= target_ks_score for stat in ks_stats.values())):
248+
if target_ks_score is not None and not pre_break:
216249
logger.error(
217-
"Hard max of %s reached but could not achieve target ks_score of %s. Got %s.",
218-
hard_max,
250+
"Hard max reached but could not achieve target ks_score of %s. Got %s. Generated %s distinct tests",
219251
target_ks_score,
220252
ks_stats,
253+
len(concrete_tests),
221254
)
222255
return concrete_tests, runs

causal_testing/json_front/json_class.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def set_variables(self, inputs: dict, outputs: dict, metas: dict):
7575
:param metas:
7676
"""
7777
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
78-
self.outputs = [Output(i["name"], i["type"]) for i in outputs]
78+
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
7979
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []
8080

8181
def setup(self):
@@ -89,11 +89,11 @@ def setup(self):
8989
self._populate_metas()
9090

9191
def _create_abstract_test_case(self, test, mutates, effects):
92-
92+
assert len(test["mutations"]) == 1
9393
abstract_test = AbstractCausalTestCase(
9494
scenario=self.modelling_scenario,
9595
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
96-
treatment_variables=self.modelling_scenario.variables[next(iter(test["mutations"]))],
96+
treatment_variable=next(self.modelling_scenario.variables[v] for v in test["mutations"]),
9797
expected_causal_effect={
9898
self.modelling_scenario.variables[variable]: effects[effect]
9999
for variable, effect in test["expectedEffect"].items()
@@ -102,6 +102,7 @@ def _create_abstract_test_case(self, test, mutates, effects):
102102
if "effect_modifiers" in test
103103
else {},
104104
estimate_type=test["estimate_type"],
105+
effect=test.get("effect", "total"),
105106
)
106107
return abstract_test
107108

@@ -125,7 +126,7 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
125126
logger.info([abstract_test.treatment_variables.name, abstract_test.treatment_variables.distribution])
126127
logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
127128
failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
128-
logger.info("%s/%s failed", failures, len(concrete_tests))
129+
logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])
129130

130131
def _execute_tests(self, concrete_tests, estimators, test, f_flag):
131132
failures = 0
@@ -152,11 +153,12 @@ def _populate_metas(self):
152153
meta.populate(self.data)
153154

154155
for var in self.metas + self.outputs:
155-
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
156-
fitter.fit()
157-
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
158-
var.distribution = getattr(scipy.stats, dist)(**params)
159-
logger.info(var.name + f"{dist}({params})")
156+
if not var.distribution:
157+
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
158+
fitter.fit()
159+
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
160+
var.distribution = getattr(scipy.stats, dist)(**params)
161+
logger.info(var.name + f" {dist}({params})")
160162

161163
def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
162164
"""Executes a singular test case, prints the results and returns the test case result
@@ -179,19 +181,15 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
179181
if causal_test_result.ci_low() and causal_test_result.ci_high():
180182
result_string = f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
181183
else:
182-
result_string = causal_test_result.test_value.value
184+
result_string = f"{causal_test_result.test_value.value} no confidence intervals"
183185
if f_flag:
184186
assert test_passes, (
185187
f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
186188
f"got {result_string}"
187189
)
188190
if not test_passes:
189191
failed = True
190-
logger.warning(
191-
" FAILED- expected %s, got %s",
192-
causal_test_case.expected_causal_effect,
193-
causal_test_result.test_value.value,
194-
)
192+
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
195193
return failed
196194

197195
def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:

causal_testing/specification/causal_dag.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
255255
gam.add_edges_from(edges_to_add)
256256

257257
min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
258-
# min_seps.remove(set(outcomes))
258+
if set(outcomes) in min_seps:
259+
min_seps.remove(set(outcomes))
259260
return min_seps
260261

261262
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
@@ -278,6 +279,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
278279
:param outcomes: A list of strings representing outcomes.
279280
:return: A list of strings representing the minimal adjustment set.
280281
"""
282+
281283
# 1. Construct the proper back-door graph's ancestor moral graph
282284
proper_backdoor_graph = self.get_proper_backdoor_graph(treatments, outcomes)
283285
ancestor_proper_backdoor_graph = proper_backdoor_graph.get_ancestor_graph(treatments, outcomes)
@@ -316,6 +318,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
316318
for adj in minimum_adjustment_sets
317319
if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, adj)
318320
]
321+
319322
return valid_minimum_adjustment_sets
320323

321324
def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str], adjustment_set: set[str]) -> bool:

causal_testing/specification/variable.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import lhsmdu
77
from pandas import DataFrame
88
from scipy.stats._distn_infrastructure import rv_generic
9-
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String
9+
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String, DatatypeRef
1010

1111
# Declare type variable
1212
# Is there a better way? I'd really like to do Variable[T](ExprRef)
@@ -22,7 +22,7 @@ def z3_types(datatype):
2222
if datatype in types:
2323
return types[datatype]
2424
if issubclass(datatype, Enum):
25-
dtype, _ = EnumSort(datatype.__name__, [x.name for x in datatype])
25+
dtype, _ = EnumSort(datatype.__name__, [str(x.value) for x in datatype])
2626
return lambda x: Const(x, dtype)
2727
if hasattr(datatype, "to_z3"):
2828
return datatype.to_z3()
@@ -153,19 +153,27 @@ def cast(self, val: Any) -> T:
153153
:rtype: T
154154
"""
155155
assert val is not None, f"Invalid value None for variable {self}"
156+
if isinstance(val, self.datatype):
157+
return val
158+
if isinstance(val, BoolRef) and self.datatype == bool:
159+
return str(val) == "True"
156160
if isinstance(val, RatNumRef) and self.datatype == float:
157161
return float(val.numerator().as_long() / val.denominator().as_long())
158162
if hasattr(val, "is_string_value") and val.is_string_value() and self.datatype == str:
159163
return val.as_string()
160-
if (isinstance(val, float) or isinstance(val, int)) and (self.datatype == int or self.datatype == float):
164+
if (isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)) and (
165+
self.datatype == int or self.datatype == float or self.datatype == bool
166+
):
161167
return self.datatype(val)
168+
if issubclass(self.datatype, Enum) and isinstance(val, DatatypeRef):
169+
return self.datatype(str(val))
162170
return self.datatype(str(val))
163171

164172
def z3_val(self, z3_var, val: Any) -> T:
165173
native_val = self.cast(val)
166174
if isinstance(native_val, Enum):
167175
values = [z3_var.sort().constructor(c)() for c in range(z3_var.sort().num_constructors())]
168-
values = [v for v in values if str(v) == str(val)]
176+
values = [v for v in values if val.__class__(str(v)) == val]
169177
assert len(values) == 1, f"Expected {values} to be length 1"
170178
return values[0]
171179
return native_val
@@ -193,7 +201,6 @@ def typestring(self) -> str:
193201
"""
194202
return type(self).__name__
195203

196-
@abstractmethod
197204
def copy(self, name: str = None) -> Variable:
198205
"""Return a new instance of the Variable with the given name, or with
199206
the original name if no name is supplied.
@@ -203,26 +210,18 @@ def copy(self, name: str = None) -> Variable:
203210
:rtype: Variable
204211
205212
"""
206-
raise NotImplementedError("Method `copy` must be instantiated.")
213+
if name:
214+
return self.__class__(name, self.datatype, self.distribution)
215+
return self.__class__(self.name, self.datatype, self.distribution)
207216

208217

209218
class Input(Variable):
210219
"""An extension of the Variable class representing inputs."""
211220

212-
def copy(self, name=None) -> Input:
213-
if name:
214-
return Input(name, self.datatype, self.distribution)
215-
return Input(self.name, self.datatype, self.distribution)
216-
217221

218222
class Output(Variable):
219223
"""An extension of the Variable class representing outputs."""
220224

221-
def copy(self, name=None) -> Output:
222-
if name:
223-
return Output(name, self.datatype, self.distribution)
224-
return Output(self.name, self.datatype, self.distribution)
225-
226225

227226
class Meta(Variable):
228227
"""An extension of the Variable class representing metavariables. These are variables which are relevant to the
@@ -242,8 +241,3 @@ class Meta(Variable):
242241
def __init__(self, name: str, datatype: T, populate: Callable[[DataFrame], DataFrame]):
243242
super().__init__(name, datatype)
244243
self.populate = populate
245-
246-
def copy(self, name=None) -> Meta:
247-
if name:
248-
return Meta(name, self.datatype, self.distribution)
249-
return Meta(self.name, self.datatype, self.distribution)

0 commit comments

Comments
 (0)