Skip to content

Commit fd7c43a

Browse files
Merge branch 'base-causal-test-case' into causal_test_case_refactor
2 parents 60d7bee + 92b6250 commit fd7c43a

28 files changed

+609
-262
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Here are some explanations for the causal inference terminology used above.
3131

3232
## Installation
3333

34-
To use the causal testing framework, clone the repository, `cd` into the root directory, and run `pip install -e .`. More detailed installation instructions can be found in the [online documentation](https://causal-testing-framework.readthedocs.io/en/latest/installation.html).
34+
See the [readthedocs site for installation instructions](https://causal-testing-framework.readthedocs.io/en/latest/installation.html).
3535

3636
## Usage
3737

causal_testing/data_collection/data_collector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from abc import ABC, abstractmethod
3+
from enum import Enum
34

45
import pandas as pd
56
import z3
@@ -140,4 +141,7 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
140141
for meta in self.scenario.metas():
141142
meta.populate(execution_data_df)
142143
scenario_execution_data_df = self.filter_valid_data(execution_data_df)
144+
for var_name, var in self.scenario.variables.items():
145+
if issubclass(var.datatype, Enum):
146+
scenario_execution_data_df[var_name] = [var.datatype(x) for x in scenario_execution_data_df[var_name]]
143147
return scenario_execution_data_df

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
import pandas as pd
55
import z3
66
from scipy import stats
7+
import itertools
78

89
from causal_testing.specification.scenario import Scenario
910
from causal_testing.specification.variable import Variable
1011
from causal_testing.testing.causal_test_case import CausalTestCase
1112
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
1213
from causal_testing.testing.base_test_case import BaseTestCase
1314

15+
from enum import Enum
16+
1417
logger = logging.getLogger(__name__)
1518

1619

@@ -25,23 +28,26 @@ def __init__(
2528
self,
2629
scenario: Scenario,
2730
intervention_constraints: set[z3.ExprRef],
28-
treatment_variables: set[Variable],
31+
treatment_variable: Variable,
2932
expected_causal_effect: dict[Variable:CausalTestOutcome],
3033
effect_modifiers: set[Variable] = None,
3134
estimate_type: str = "ate",
35+
effect: str = "total",
3236
):
33-
assert {treatment_variables}.issubset(scenario.variables.values()), (
34-
"Treatment variables must be a subset of variables."
35-
+ f" Instead got:\ntreatment_variables={treatment_variables}\nvariables={scenario.variables}"
36-
)
37+
if treatment_variable not in scenario.variables.values():
38+
raise ValueError(
39+
"Treatment variables must be a subset of variables."
40+
+ f" Instead got:\ntreatment_variables={treatment_variable}\nvariables={scenario.variables}"
41+
)
3742

3843
assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"
3944

4045
self.scenario = scenario
4146
self.intervention_constraints = intervention_constraints
42-
self.treatment_variables = treatment_variables
47+
self.treatment_variable = treatment_variable
4348
self.expected_causal_effect = expected_causal_effect
4449
self.estimate_type = estimate_type
50+
self.effect = effect
4551

4652
if effect_modifiers is not None:
4753
self.effect_modifiers = effect_modifiers
@@ -101,7 +107,12 @@ def _generate_concrete_tests(
101107
for c in self.intervention_constraints:
102108
optimizer.assert_and_track(c, str(c))
103109

104-
optimizer.add_soft([self.scenario.variables[v].z3 == row[v] for v in run_columns])
110+
for v in run_columns:
111+
optimizer.add_soft(
112+
self.scenario.variables[v].z3
113+
== self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])
114+
)
115+
105116
if optimizer.check() == z3.unsat:
106117
logger.warning(
107118
"Satisfiability of test case was unsat.\n" "Constraints \n %s \n Unsat core %s",
@@ -110,13 +121,17 @@ def _generate_concrete_tests(
110121
)
111122
model = optimizer.model()
112123

113-
base_test_case = BaseTestCase(self.treatment_variables, list(self.expected_causal_effect.keys())[0])
124+
base_test_case = BaseTestCase(
125+
treatment_variable=self.treatment_variable,
126+
outcome_variable=list(self.expected_causal_effect.keys())[0],
127+
effect=self.effect,
128+
)
114129

115130
concrete_test = CausalTestCase(
116131
base_test_case=base_test_case,
117-
control_value=self.treatment_variables.cast(model[self.treatment_variables.z3]),
118-
treatment_value=self.treatment_variables.cast(
119-
model[self.scenario.treatment_variables[self.treatment_variables.name].z3]
132+
control_value=self.treatment_variable.cast(model[self.treatment_variable.z3]),
133+
treatment_value=self.treatment_variable.cast(
134+
model[self.scenario.treatment_variables[self.treatment_variable.name].z3]
120135
),
121136
expected_causal_effect=list(self.expected_causal_effect.values())[0],
122137
estimate_type=self.estimate_type,
@@ -131,20 +146,22 @@ def _generate_concrete_tests(
131146
+ f"{constraints}\nUsing value {v.cast(model[v.z3])} instead in test\n{concrete_test}"
132147
)
133148

134-
concrete_tests.append(concrete_test)
135-
# Control run
136-
control_run = {
137-
v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
138-
}
139-
control_run["bin"] = index
140-
runs.append(control_run)
141-
# Treatment run
142-
if rct:
143-
treatment_run = control_run.copy()
144-
treatment_run.update({concrete_test.treatment_variable.name: concrete_test.treatment_value})
145-
# treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
146-
treatment_run["bin"] = index
147-
runs.append(treatment_run)
149+
150+
if not any([vars(t) == vars(concrete_test) for t in concrete_tests]):
151+
concrete_tests.append(concrete_test)
152+
# Control run
153+
control_run = {
154+
v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
155+
}
156+
control_run["bin"] = index
157+
runs.append(control_run)
158+
# Treatment run
159+
if rct:
160+
treatment_run = control_run.copy()
161+
treatment_run.update({concrete_test.treatment_variable.name: concrete_test.treatment_value})
162+
treatment_run["bin"] = index
163+
runs.append(treatment_run)
164+
148165

149166
return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
150167

@@ -180,9 +197,12 @@ def generate_concrete_tests(
180197
runs = pd.DataFrame()
181198
ks_stats = []
182199

200+
pre_break = False
183201
for i in range(hard_max):
184202
concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed + i)
185-
concrete_tests += concrete_tests_
203+
for t_ in concrete_tests_:
204+
if not any([vars(t_) == vars(t) for t in concrete_tests]):
205+
concrete_tests.append(t_)
186206
runs = pd.concat([runs, runs_])
187207
assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
188208

@@ -209,14 +229,32 @@ def generate_concrete_tests(
209229
for var in effect_modifier_configs.columns
210230
}
211231
)
212-
if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
232+
control_values = [test.control_input_configuration[self.treatment_variable] for test in concrete_tests]
233+
treatment_values = [test.treatment_input_configuration[self.treatment_variable] for test in concrete_tests]
234+
235+
if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
236+
set(zip(control_values, treatment_values))
237+
):
238+
pre_break = True
239+
break
240+
if issubclass(self.treatment_variable.datatype, Enum) and set(
241+
{
242+
(x, y)
243+
for x, y in itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)
244+
if x != y
245+
}
246+
).issubset(zip(control_values, treatment_values)):
247+
pre_break = True
248+
break
249+
elif target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
250+
pre_break = True
213251
break
214252

215-
if target_ks_score is not None and not all((stat <= target_ks_score for stat in ks_stats.values())):
253+
if target_ks_score is not None and not pre_break:
216254
logger.error(
217-
"Hard max of %s reached but could not achieve target ks_score of %s. Got %s.",
218-
hard_max,
255+
"Hard max reached but could not achieve target ks_score of %s. Got %s. Generated %s distinct tests",
219256
target_ks_score,
220257
ks_stats,
258+
len(concrete_tests),
221259
)
222260
return concrete_tests, runs

causal_testing/json_front/json_class.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def set_variables(self, inputs: dict, outputs: dict, metas: dict):
7575
:param metas:
7676
"""
7777
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
78-
self.outputs = [Output(i["name"], i["type"]) for i in outputs]
78+
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
7979
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []
8080

8181
def setup(self):
@@ -89,11 +89,11 @@ def setup(self):
8989
self._populate_metas()
9090

9191
def _create_abstract_test_case(self, test, mutates, effects):
92-
92+
assert len(test["mutations"]) == 1
9393
abstract_test = AbstractCausalTestCase(
9494
scenario=self.modelling_scenario,
9595
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
96-
treatment_variables=self.modelling_scenario.variables[next(iter(test["mutations"]))],
96+
treatment_variable=next(self.modelling_scenario.variables[v] for v in test["mutations"]),
9797
expected_causal_effect={
9898
self.modelling_scenario.variables[variable]: effects[effect]
9999
for variable, effect in test["expectedEffect"].items()
@@ -102,6 +102,7 @@ def _create_abstract_test_case(self, test, mutates, effects):
102102
if "effect_modifiers" in test
103103
else {},
104104
estimate_type=test["estimate_type"],
105+
effect=test.get("effect", "total"),
105106
)
106107
return abstract_test
107108

@@ -122,10 +123,10 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
122123
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
123124
logger.info("Executing test: %s", test["name"])
124125
logger.info(abstract_test)
125-
logger.info([abstract_test.treatment_variables.name, abstract_test.treatment_variables.distribution])
126+
logger.info([abstract_test.treatment_variable.name, abstract_test.treatment_variable.distribution])
126127
logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
127128
failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
128-
logger.info("%s/%s failed", failures, len(concrete_tests))
129+
logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])
129130

130131
def _execute_tests(self, concrete_tests, estimators, test, f_flag):
131132
failures = 0
@@ -152,11 +153,12 @@ def _populate_metas(self):
152153
meta.populate(self.data)
153154

154155
for var in self.metas + self.outputs:
155-
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
156-
fitter.fit()
157-
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
158-
var.distribution = getattr(scipy.stats, dist)(**params)
159-
logger.info(var.name + f"{dist}({params})")
156+
if not var.distribution:
157+
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
158+
fitter.fit()
159+
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
160+
var.distribution = getattr(scipy.stats, dist)(**params)
161+
logger.info(var.name + f" {dist}({params})")
160162

161163
def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
162164
"""Executes a singular test case, prints the results and returns the test case result
@@ -177,23 +179,17 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
177179

178180
result_string = str()
179181
if causal_test_result.ci_low() and causal_test_result.ci_high():
180-
result_string = (
181-
f"{causal_test_result.ci_low()} < {causal_test_result.ate} < {causal_test_result.ci_high()}"
182-
)
182+
result_string = f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
183183
else:
184-
result_string = causal_test_result.ate
184+
result_string = f"{causal_test_result.test_value.value} no confidence intervals"
185185
if f_flag:
186186
assert test_passes, (
187187
f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
188188
f"got {result_string}"
189189
)
190190
if not test_passes:
191191
failed = True
192-
logger.warning(
193-
" FAILED- expected %s, got %s",
194-
causal_test_case.expected_causal_effect,
195-
causal_test_result.ate,
196-
)
192+
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
197193
return failed
198194

199195
def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:

causal_testing/specification/causal_dag.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
255255
gam.add_edges_from(edges_to_add)
256256

257257
min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
258-
# min_seps.remove(set(outcomes))
258+
if set(outcomes) in min_seps:
259+
min_seps.remove(set(outcomes))
259260
return min_seps
260261

261262
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
@@ -278,6 +279,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
278279
:param outcomes: A list of strings representing outcomes.
279280
:return: A list of strings representing the minimal adjustment set.
280281
"""
282+
281283
# 1. Construct the proper back-door graph's ancestor moral graph
282284
proper_backdoor_graph = self.get_proper_backdoor_graph(treatments, outcomes)
283285
ancestor_proper_backdoor_graph = proper_backdoor_graph.get_ancestor_graph(treatments, outcomes)
@@ -316,6 +318,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
316318
for adj in minimum_adjustment_sets
317319
if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, adj)
318320
]
321+
319322
return valid_minimum_adjustment_sets
320323

321324
def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str], adjustment_set: set[str]) -> bool:

0 commit comments

Comments
 (0)