
Commit ca2b630

Merge branch 'main' into PyPI_setup
2 parents: b270434 + 1098043

33 files changed: +1110 -375 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ Here are some explanations for the causal inference terminology used above.
 
 ## Installation
 
-To use the causal testing framework, clone the repository, `cd` into the root directory, and run `pip install -e .`. More detailed installation instructions can be found in the [online documentation](https://causal-testing-framework.readthedocs.io/en/latest/installation.html).
+See the readthedocs site for [installation instructions](https://causal-testing-framework.readthedocs.io/en/latest/installation.html).
 
 ## Usage
 

causal_testing/data_collection/data_collector.py

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
+from enum import Enum
 
 import pandas as pd
 import z3
@@ -140,4 +141,7 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
         for meta in self.scenario.metas():
             meta.populate(execution_data_df)
         scenario_execution_data_df = self.filter_valid_data(execution_data_df)
+        for var_name, var in self.scenario.variables.items():
+            if issubclass(var.datatype, Enum):
+                scenario_execution_data_df[var_name] = [var.datatype(x) for x in scenario_execution_data_df[var_name]]
         return scenario_execution_data_df
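
The added loop re-casts Enum-typed columns from raw execution values back to Enum members, so downstream code compares against the variable's declared datatype. A minimal standalone sketch of that cast, assuming a toy `Color` enum and a plain DataFrame (both illustrative, not framework API):

```python
from enum import Enum

import pandas as pd


class Color(Enum):
    """Illustrative Enum standing in for a variable's declared datatype."""
    RED = "red"
    BLUE = "blue"


# Execution data typically arrives as raw values (here, strings).
df = pd.DataFrame({"color": ["red", "blue", "red"]})

# Mirror the added loop: re-cast raw values to Enum members.
if issubclass(Color, Enum):
    df["color"] = [Color(x) for x in df["color"]]

print(df["color"].iloc[0])  # Color.RED
```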

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 72 additions & 32 deletions
@@ -4,18 +4,22 @@
 import pandas as pd
 import z3
 from scipy import stats
+import itertools
 
 from causal_testing.specification.scenario import Scenario
 from causal_testing.specification.variable import Variable
 from causal_testing.testing.causal_test_case import CausalTestCase
 from causal_testing.testing.causal_test_outcome import CausalTestOutcome
+from causal_testing.testing.base_test_case import BaseTestCase
+
+from enum import Enum
 
 logger = logging.getLogger(__name__)
 
 
 class AbstractCausalTestCase:
     """
-    An abstract test case serves as a generator for concrete test cases. Instead of having concrete conctrol
+    An abstract test case serves as a generator for concrete test cases. Instead of having concrete control
     and treatment values, we instead just specify the intervention and the treatment variables. This then
     enables potentially infinite concrete test cases to be generated between different values of the treatment.
     """
@@ -24,23 +28,26 @@ def __init__(
         self,
         scenario: Scenario,
         intervention_constraints: set[z3.ExprRef],
-        treatment_variables: set[Variable],
+        treatment_variable: Variable,
         expected_causal_effect: dict[Variable:CausalTestOutcome],
         effect_modifiers: set[Variable] = None,
         estimate_type: str = "ate",
+        effect: str = "total",
     ):
-        assert treatment_variables.issubset(scenario.variables.values()), (
-            "Treatment variables must be a subset of variables."
-            + f" Instead got:\ntreatment_variables={treatment_variables}\nvariables={scenario.variables}"
-        )
+        if treatment_variable not in scenario.variables.values():
+            raise ValueError(
+                "Treatment variables must be a subset of variables."
+                + f" Instead got:\ntreatment_variables={treatment_variable}\nvariables={scenario.variables}"
+            )
 
         assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"
 
         self.scenario = scenario
         self.intervention_constraints = intervention_constraints
-        self.treatment_variables = treatment_variables
+        self.treatment_variable = treatment_variable
         self.expected_causal_effect = expected_causal_effect
         self.estimate_type = estimate_type
+        self.effect = effect
 
         if effect_modifiers is not None:
             self.effect_modifiers = effect_modifiers
@@ -100,7 +107,12 @@ def _generate_concrete_tests(
             for c in self.intervention_constraints:
                 optimizer.assert_and_track(c, str(c))
 
-            optimizer.add_soft([self.scenario.variables[v].z3 == row[v] for v in run_columns])
+            for v in run_columns:
+                optimizer.add_soft(
+                    self.scenario.variables[v].z3
+                    == self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])
+                )
+
             if optimizer.check() == z3.unsat:
                 logger.warning(
                     "Satisfiability of test case was unsat.\n" "Constraints \n %s \n Unsat core %s",
@@ -109,13 +121,19 @@ def _generate_concrete_tests(
                 )
             model = optimizer.model()
 
+            base_test_case = BaseTestCase(
+                treatment_variable=self.treatment_variable,
+                outcome_variable=list(self.expected_causal_effect.keys())[0],
+                effect=self.effect,
+            )
+
             concrete_test = CausalTestCase(
-                control_input_configuration={v: v.cast(model[v.z3]) for v in self.treatment_variables},
-                treatment_input_configuration={
-                    v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in self.treatment_variables
-                },
+                base_test_case=base_test_case,
+                control_value=self.treatment_variable.cast(model[self.treatment_variable.z3]),
+                treatment_value=self.treatment_variable.cast(
+                    model[self.scenario.treatment_variables[self.treatment_variable.name].z3]
+                ),
                 expected_causal_effect=list(self.expected_causal_effect.values())[0],
-                outcome_variables=list(self.expected_causal_effect.keys()),
                 estimate_type=self.estimate_type,
                 effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
             )
@@ -128,19 +146,20 @@ def _generate_concrete_tests(
                     + f"{constraints}\nUsing value {v.cast(model[v.z3])} instead in test\n{concrete_test}"
                 )
 
-            concrete_tests.append(concrete_test)
-            # Control run
-            control_run = {
-                v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
-            }
-            control_run["bin"] = index
-            runs.append(control_run)
-            # Treatment run
-            if rct:
-                treatment_run = control_run.copy()
-                treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
-                treatment_run["bin"] = index
-                runs.append(treatment_run)
+            if not any([vars(t) == vars(concrete_test) for t in concrete_tests]):
+                concrete_tests.append(concrete_test)
+                # Control run
+                control_run = {
+                    v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
+                }
+                control_run["bin"] = index
+                runs.append(control_run)
+                # Treatment run
+                if rct:
+                    treatment_run = control_run.copy()
+                    treatment_run.update({concrete_test.treatment_variable.name: concrete_test.treatment_value})
+                    treatment_run["bin"] = index
+                    runs.append(treatment_run)
 
         return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
 
@@ -176,13 +195,16 @@ def generate_concrete_tests(
         runs = pd.DataFrame()
         ks_stats = []
 
+        pre_break = False
        for i in range(hard_max):
             concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed + i)
-            concrete_tests += concrete_tests_
+            for t_ in concrete_tests_:
+                if not any([vars(t_) == vars(t) for t in concrete_tests]):
+                    concrete_tests.append(t_)
             runs = pd.concat([runs, runs_])
             assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
 
-            control_configs = pd.DataFrame([test.control_input_configuration for test in concrete_tests])
+            control_configs = pd.DataFrame([{test.treatment_variable: test.control_value} for test in concrete_tests])
             ks_stats = {
                 var: stats.kstest(control_configs[var], var.distribution.cdf).statistic
                 for var in control_configs.columns
@@ -205,14 +227,32 @@ def generate_concrete_tests(
                     for var in effect_modifier_configs.columns
                 }
             )
-            if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
+            control_values = [test.control_value for test in concrete_tests]
+            treatment_values = [test.treatment_value for test in concrete_tests]
+
+            if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
+                set(zip(control_values, treatment_values))
+            ):
+                pre_break = True
+                break
+            if issubclass(self.treatment_variable.datatype, Enum) and set(
+                {
+                    (x, y)
+                    for x, y in itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)
+                    if x != y
+                }
+            ).issubset(zip(control_values, treatment_values)):
+                pre_break = True
+                break
+            elif target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
+                pre_break = True
                 break
 
-        if target_ks_score is not None and not all((stat <= target_ks_score for stat in ks_stats.values())):
+        if target_ks_score is not None and not pre_break:
             logger.error(
-                "Hard max of %s reached but could not achieve target ks_score of %s. Got %s.",
-                hard_max,
+                "Hard max reached but could not achieve target ks_score of %s. Got %s. Generated %s distinct tests",
                 target_ks_score,
                 ks_stats,
+                len(concrete_tests),
             )
         return concrete_tests, runs
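
Two of the changes above interact: duplicate concrete tests are now filtered with `vars(t) == vars(concrete_test)`, and generation can stop before `hard_max` once a Boolean or Enum treatment has every ordered (control, treatment) pair covered. A standalone sketch of that coverage check, under the assumption that a variable's `datatype` is the raw Python type; the `Dose` enum and the function name are invented for illustration:

```python
import itertools
from enum import Enum


class Dose(Enum):
    """Hypothetical Enum treatment datatype."""
    LOW = 1
    HIGH = 2


def all_pairs_covered(datatype, control_values, treatment_values) -> bool:
    """True once every distinct ordered value pair appears across the tests."""
    if datatype is bool:
        required = {(True, False), (False, True)}
    elif issubclass(datatype, Enum):
        required = {(x, y) for x, y in itertools.product(datatype, datatype) if x != y}
    else:
        return False  # continuous treatments fall back on the KS score instead
    return required.issubset(set(zip(control_values, treatment_values)))


# Both orderings of LOW/HIGH are present, so generation could stop early.
print(all_pairs_covered(Dose, [Dose.LOW, Dose.HIGH], [Dose.HIGH, Dose.LOW]))  # True
```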

causal_testing/json_front/json_class.py

Lines changed: 21 additions & 22 deletions
@@ -75,7 +75,7 @@ def set_variables(self, inputs: dict, outputs: dict, metas: dict):
         :param metas:
         """
         self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
-        self.outputs = [Output(i["name"], i["type"]) for i in outputs]
+        self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
         self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []
 
     def setup(self):
@@ -89,10 +89,11 @@ def setup(self):
         self._populate_metas()
 
     def _create_abstract_test_case(self, test, mutates, effects):
+        assert len(test["mutations"]) == 1
         abstract_test = AbstractCausalTestCase(
             scenario=self.modelling_scenario,
             intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
-            treatment_variables={self.modelling_scenario.variables[v] for v in test["mutations"]},
+            treatment_variable=next(self.modelling_scenario.variables[v] for v in test["mutations"]),
             expected_causal_effect={
                 self.modelling_scenario.variables[variable]: effects[effect]
                 for variable, effect in test["expectedEffect"].items()
@@ -101,6 +102,7 @@ def _create_abstract_test_case(self, test, mutates, effects):
             if "effect_modifiers" in test
             else {},
             estimate_type=test["estimate_type"],
+            effect=test.get("effect", "total"),
         )
         return abstract_test
 
@@ -121,10 +123,10 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
             concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
             logger.info("Executing test: %s", test["name"])
             logger.info(abstract_test)
-            logger.info([(v.name, v.distribution) for v in abstract_test.treatment_variables])
+            logger.info([abstract_test.treatment_variable.name, abstract_test.treatment_variable.distribution])
             logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
             failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
-            logger.info("%s/%s failed", failures, len(concrete_tests))
+            logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])
 
     def _execute_tests(self, concrete_tests, estimators, test, f_flag):
         failures = 0
@@ -151,11 +153,12 @@ def _populate_metas(self):
             meta.populate(self.data)
 
         for var in self.metas + self.outputs:
-            fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
-            fitter.fit()
-            (dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
-            var.distribution = getattr(scipy.stats, dist)(**params)
-            logger.info(var.name + f"{dist}({params})")
+            if not var.distribution:
+                fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
+                fitter.fit()
+                (dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
+                var.distribution = getattr(scipy.stats, dist)(**params)
+                logger.info(var.name + f" {dist}({params})")
 
     def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
         """Executes a singular test case, prints the results and returns the test case result
@@ -178,19 +181,15 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
         if causal_test_result.ci_low() and causal_test_result.ci_high():
             result_string = f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
         else:
-            result_string = causal_test_result.test_value.value
+            result_string = f"{causal_test_result.test_value.value} no confidence intervals"
         if f_flag:
             assert test_passes, (
                 f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
                 f"got {result_string}"
             )
         if not test_passes:
             failed = True
-            logger.warning(
-                " FAILED- expected %s, got %s",
-                causal_test_case.expected_causal_effect,
-                causal_test_result.test_value.value,
-            )
+            logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
         return failed
 
     def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
@@ -202,15 +201,15 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
         """
         data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
         causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
-        causal_test_engine.identification(causal_test_case)
-        treatment_vars = list(causal_test_case.treatment_input_configuration)
-        minimal_adjustment_set = causal_test_engine.minimal_adjustment_set - {v.name for v in treatment_vars}
+        minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
+        treatment_var = causal_test_case.treatment_variable
+        minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
         estimation_model = estimator(
-            (list(treatment_vars)[0].name,),
-            [causal_test_case.treatment_input_configuration[v] for v in treatment_vars][0],
-            [causal_test_case.control_input_configuration[v] for v in treatment_vars][0],
+            (treatment_var.name,),
+            causal_test_case.treatment_value,
+            causal_test_case.control_value,
             minimal_adjustment_set,
-            (list(causal_test_case.outcome_variables)[0].name,),
+            (causal_test_case.outcome_variable.name,),
             causal_test_engine.scenario_execution_data_df,
             effect_modifiers=causal_test_case.effect_modifier_configuration,
         )
causal_testing/specification/causal_dag.py

Lines changed: 27 additions & 1 deletion
@@ -255,7 +255,8 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[str]):
         gam.add_edges_from(edges_to_add)
 
         min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
-        # min_seps.remove(set(outcomes))
+        if set(outcomes) in min_seps:
+            min_seps.remove(set(outcomes))
         return min_seps
 
     def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
@@ -278,6 +279,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
         :param outcomes: A list of strings representing outcomes.
         :return: A list of strings representing the minimal adjustment set.
         """
+
         # 1. Construct the proper back-door graph's ancestor moral graph
         proper_backdoor_graph = self.get_proper_backdoor_graph(treatments, outcomes)
         ancestor_proper_backdoor_graph = proper_backdoor_graph.get_ancestor_graph(treatments, outcomes)
@@ -316,6 +318,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
             for adj in minimum_adjustment_sets
             if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, adj)
         ]
+
         return valid_minimum_adjustment_sets
 
     def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str], adjustment_set: set[str]) -> bool:
@@ -465,5 +468,28 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
                 return True
         return any([self.depends_on_outputs(n, scenario) for n in self.graph.predecessors(node)])
 
+    def identification(self, base_test_case):
+        """Identify and return the minimum adjustment set
+
+        :param base_test_case: A base test case instance containing the outcome_variable and the
+        treatment_variable required for identification.
+        :return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
+        estimate as opposed to a purely associational estimate.
+        """
+        minimal_adjustment_sets = []
+        if base_test_case.effect == "total":
+            minimal_adjustment_sets = self.enumerate_minimal_adjustment_sets(
+                [base_test_case.treatment_variable.name], [base_test_case.outcome_variable.name]
+            )
+        elif base_test_case.effect == "direct":
+            minimal_adjustment_sets = self.direct_effect_adjustment_sets(
+                [base_test_case.treatment_variable.name], [base_test_case.outcome_variable.name]
+            )
+        else:
+            raise ValueError("Causal effect should be 'total' or 'direct'")
+
+        minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
+        return minimal_adjustment_set
+
     def __str__(self):
         return f"Nodes: {self.graph.nodes}\nEdges: {self.graph.edges}"
