Skip to content

Commit 4b14212

Browse files
committed
Tests passing again
1 parent a15e15b commit 4b14212

File tree

12 files changed

+90
-755
lines changed

12 files changed

+90
-755
lines changed

causal_testing/json_front/json_class.py

Lines changed: 9 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111
from statistics import StatisticsError
1212

1313
import pandas as pd
14-
import scipy
15-
from fitter import Fitter, get_common_distributions
1614

17-
from causal_testing.generation.abstract_causal_test_case import AbstractCausalTestCase
1815
from causal_testing.specification.causal_dag import CausalDAG
1916
from causal_testing.specification.causal_specification import CausalSpecification
2017
from causal_testing.specification.scenario import Scenario
@@ -90,33 +87,6 @@ def setup(self, scenario: Scenario, ignore_cycles=False):
9087
)
9188
self._populate_metas()
9289

93-
def _create_abstract_test_case(self, test, mutates, effects):
94-
assert len(test["mutations"]) == 1
95-
treatment_var = next(self.scenario.variables[v] for v in test["mutations"])
96-
97-
if not treatment_var.distribution:
98-
fitter = Fitter(self.df[treatment_var.name], distributions=get_common_distributions())
99-
fitter.fit()
100-
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
101-
treatment_var.distribution = getattr(scipy.stats, dist)(**params)
102-
self._append_to_file(treatment_var.name + f" {dist}({params})", logging.INFO)
103-
104-
abstract_test = AbstractCausalTestCase(
105-
scenario=self.scenario,
106-
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
107-
treatment_variable=treatment_var,
108-
expected_causal_effect={
109-
self.scenario.variables[variable]: effects[effect]
110-
for variable, effect in test["expected_effect"].items()
111-
},
112-
effect_modifiers=(
113-
{self.scenario.variables[v] for v in test["effect_modifiers"]} if "effect_modifiers" in test else {}
114-
),
115-
estimate_type=test["estimate_type"],
116-
effect=test.get("effect", "total"),
117-
)
118-
return abstract_test
119-
12090
def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False, mutates: dict = None):
12191
"""Runs and evaluates each test case specified in the JSON input
12292
@@ -139,9 +109,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
139109
test=test, f_flag=f_flag, effects=effects, estimate_type=test["estimate_type"]
140110
)
141111
else:
142-
failed, msg = self._run_metamorphic_tests(
143-
test=test, f_flag=f_flag, effects=effects, mutates=mutates
144-
)
112+
raise NotImplementedError("Tried to call deprecated method _run_metamorphic_tests")
145113
test["failed"] = failed
146114
test["result"] = msg
147115
return self.test_plan["tests"]
@@ -189,8 +157,6 @@ def _run_concrete_metamorphic_test(self, test: dict, f_flag: bool, effects: dict
189157
causal_test_case = CausalTestCase(
190158
base_test_case=base_test_case,
191159
expected_causal_effect=effects[test["expected_effect"][outcome_variable]],
192-
control_value=test["control_value"],
193-
treatment_value=test["treatment_value"],
194160
estimate_type=test["estimate_type"],
195161
)
196162
failed, msg = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
@@ -205,41 +171,6 @@ def _run_concrete_metamorphic_test(self, test: dict, f_flag: bool, effects: dict
205171
self._append_to_file(msg, logging.INFO)
206172
return failed, msg
207173

208-
def _run_metamorphic_tests(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
209-
"""Builds structures and runs test case for tests with an estimate_type of 'ate'.
210-
211-
:param test: Single JSON test definition stored in a mapping (dict)
212-
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
213-
:param effects: Dictionary mapping effect class instances to string representations.
214-
:param mutates: Dictionary mapping mutation functions to string representations.
215-
:return: String containing the message to be outputted
216-
"""
217-
if "sample_size" in test:
218-
sample_size = test["sample_size"]
219-
else:
220-
sample_size = 5
221-
if "target_ks_score" in test:
222-
target_ks_score = test["target_ks_score"]
223-
else:
224-
target_ks_score = 0.05
225-
abstract_test = self._create_abstract_test_case(test, mutates, effects)
226-
concrete_tests, _ = abstract_test.generate_concrete_tests(
227-
sample_size=sample_size, target_ks_score=target_ks_score
228-
)
229-
failures, _ = self._execute_tests(concrete_tests, test, f_flag)
230-
231-
msg = (
232-
f"Executing test: {test['name']} \n"
233-
+ " abstract_test \n"
234-
+ f" {abstract_test} \n"
235-
+ f" {abstract_test.treatment_variable.name},"
236-
+ f" {abstract_test.treatment_variable.distribution} \n"
237-
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
238-
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
239-
)
240-
self._append_to_file(msg, logging.INFO)
241-
return failures, msg
242-
243174
def _execute_tests(self, concrete_tests, test, f_flag):
244175
failures = 0
245176
details = []
@@ -274,7 +205,8 @@ def _execute_test_case(
274205
failed = False
275206

276207
estimation_model = self._setup_test(causal_test_case=causal_test_case, test=test)
277-
causal_test_result = causal_test_case.execute_test(estimator=estimation_model)
208+
causal_test_case.estimator = estimation_model
209+
causal_test_result = causal_test_case.execute_test()
278210
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
279211

280212
if "coverage" in test and test["coverage"]:
@@ -307,6 +239,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
307239
- estimation_model - Estimator instance for the test being run
308240
"""
309241
estimator_kwargs = {}
242+
treatment_variable = next(self.scenario.variables[v] for v in test["mutations"])
310243
if "formula" in test:
311244
if test["estimator"] != (LinearRegressionEstimator or LogisticRegressionEstimator):
312245
raise TypeError(
@@ -319,14 +252,14 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
319252
minimal_adjustment_set = self.causal_specification.causal_dag.identification(
320253
causal_test_case.base_test_case
321254
)
322-
minimal_adjustment_set = minimal_adjustment_set - {causal_test_case.treatment_variable}
255+
minimal_adjustment_set = minimal_adjustment_set - {treatment_variable}
323256
estimator_kwargs["adjustment_set"] = minimal_adjustment_set
324257

325258
estimator_kwargs["query"] = test["query"] if "query" in test else ""
326-
estimator_kwargs["treatment"] = causal_test_case.treatment_variable.name
327-
estimator_kwargs["treatment_value"] = causal_test_case.treatment_value
328-
estimator_kwargs["control_value"] = causal_test_case.control_value
329-
estimator_kwargs["outcome"] = causal_test_case.outcome_variable.name
259+
estimator_kwargs["treatment"] = treatment_variable.name
260+
estimator_kwargs["treatment_value"] = test.get("treatment_value")
261+
estimator_kwargs["control_value"] = test.get("control_value")
262+
estimator_kwargs["outcome"] = next(v for v in test["expected_effect"])
330263
estimator_kwargs["effect_modifiers"] = causal_test_case.effect_modifier_configuration
331264
estimator_kwargs["df"] = self.df
332265
estimator_kwargs["alpha"] = test["alpha"] if "alpha" in test else 0.05

causal_testing/testing/causal_test_case.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ def __init__(
2727
self,
2828
base_test_case: BaseTestCase,
2929
expected_causal_effect: CausalTestOutcome,
30-
control_value: Any = None,
31-
treatment_value: Any = None,
3230
estimate_type: str = "ate",
3331
estimate_params: dict = None,
3432
effect_modifier_configuration: dict[Variable:Any] = None,
@@ -37,18 +35,14 @@ def __init__(
3735
"""
3836
:param base_test_case: A BaseTestCase object consisting of a treatment variable, outcome variable and effect
3937
:param expected_causal_effect: The expected causal effect (Positive, Negative, No Effect).
40-
:param control_value: The control value for the treatment variable (before intervention).
41-
:param treatment_value: The treatment value for the treatment variable (after intervention).
4238
:param estimate_type: A string which denotes the type of estimate to return.
4339
:param effect_modifier_configuration: The assignment of the effect modifiers to use for estimates.
4440
:param estimator: An Estimator class object
4541
"""
4642
self.base_test_case = base_test_case
47-
self.control_value = control_value
4843
self.expected_causal_effect = expected_causal_effect
4944
self.outcome_variable = base_test_case.outcome_variable
5045
self.treatment_variable = base_test_case.treatment_variable
51-
self.treatment_value = treatment_value
5246
self.estimate_type = estimate_type
5347
self.estimator = estimator
5448
if estimate_params is None:
@@ -60,27 +54,34 @@ def __init__(
6054
else:
6155
self.effect_modifier_configuration = {}
6256

63-
def execute_test(self) -> CausalTestResult:
64-
"""Execute a causal test case and return the causal test result.
57+
def execute_test(self, estimator: type(Estimator) = None) -> CausalTestResult:
58+
"""
59+
Execute a causal test case and return the causal test result.
60+
:param estimator: An alternative estimator. Defaults to `self.estimator`. This parameter is useful when you want
61+
to execute a test with different data or a different equational form, but don't want to redefine the whole test
62+
case.
6563
6664
:return causal_test_result: A CausalTestResult for the executed causal test case.
6765
"""
6866

69-
if not hasattr(self.estimator, f"estimate_{self.estimate_type}"):
70-
raise AttributeError(f"{self.estimator.__class__} has no {self.estimate_type} method.")
71-
estimate_effect = getattr(self.estimator, f"estimate_{self.estimate_type}")
67+
if estimator is None:
68+
estimator = self.estimator
69+
70+
if not hasattr(estimator, f"estimate_{self.estimate_type}"):
71+
raise AttributeError(f"{estimator.__class__} has no {self.estimate_type} method.")
72+
estimate_effect = getattr(estimator, f"estimate_{self.estimate_type}")
7273
effect, confidence_intervals = estimate_effect(**self.estimate_params)
7374
return CausalTestResult(
74-
estimator=self.estimator,
75+
estimator=estimator,
7576
test_value=TestValue(self.estimate_type, effect),
7677
effect_modifier_configuration=self.effect_modifier_configuration,
7778
confidence_intervals=confidence_intervals,
7879
)
7980

8081
def __str__(self):
81-
treatment_config = {self.treatment_variable.name: self.treatment_value}
82-
control_config = {self.treatment_variable.name: self.control_value}
83-
outcome_variable = {self.outcome_variable}
82+
treatment_config = {self.treatment_variable.name: self.estimator.treatment_value}
83+
control_config = {self.treatment_variable.name: self.estimator.control_value}
84+
outcome_variable = {self.outcome_variable.name}
8485
return (
8586
f"Running {treatment_config} instead of {control_config} should cause the following "
8687
f"changes to {outcome_variable}: {self.expected_causal_effect}."

examples/covasim_/doubling_beta/example_beta.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -69,33 +69,37 @@ def doubling_beta_CATE_on_csv(
6969

7070
# 6. Create a causal test case
7171
causal_test_case = CausalTestCase(
72-
base_test_case=base_test_case, expected_causal_effect=Positive, control_value=0.016, treatment_value=0.032
73-
)
74-
75-
linear_regression_estimator = LinearRegressionEstimator(
76-
"beta",
77-
0.032,
78-
0.016,
79-
{"avg_age", "contacts"}, # We use custom adjustment set
80-
"cum_infections",
81-
df=past_execution_df,
82-
formula="cum_infections ~ beta + I(beta ** 2) + avg_age + contacts",
72+
base_test_case=base_test_case,
73+
expected_causal_effect=Positive,
74+
estimator=LinearRegressionEstimator(
75+
"beta",
76+
0.032,
77+
0.016,
78+
{"avg_age", "contacts"}, # We use custom adjustment set
79+
"cum_infections",
80+
df=past_execution_df,
81+
formula="cum_infections ~ beta + I(beta ** 2) + avg_age + contacts",
82+
),
8383
)
8484

8585
# Add squared terms for beta, since it has a quadratic relationship with cumulative infections
86-
causal_test_result = causal_test_case.execute_test(estimator=linear_regression_estimator)
86+
causal_test_result = causal_test_case.execute_test()
8787

8888
# Repeat for association estimate (no adjustment)
89-
no_adjustment_linear_regression_estimator = LinearRegressionEstimator(
90-
"beta",
91-
0.032,
92-
0.016,
93-
set(),
94-
"cum_infections",
95-
df=past_execution_df,
96-
formula="cum_infections ~ beta + I(beta ** 2)",
89+
causal_test_case = CausalTestCase(
90+
base_test_case=base_test_case,
91+
expected_causal_effect=Positive,
92+
estimator=LinearRegressionEstimator(
93+
"beta",
94+
0.032,
95+
0.016,
96+
set(),
97+
"cum_infections",
98+
df=past_execution_df,
99+
formula="cum_infections ~ beta + I(beta ** 2)",
100+
),
97101
)
98-
association_test_result = causal_test_case.execute_test(estimator=no_adjustment_linear_regression_estimator)
102+
association_test_result = causal_test_case.execute_test()
99103

100104
# Store results for plotting
101105
results_dict["association"] = {
@@ -116,7 +120,7 @@ def doubling_beta_CATE_on_csv(
116120
# Repeat causal inference after deleting all rows with treatment value to obtain counterfactual inferences
117121
if simulate_counterfactuals:
118122
counterfactual_past_execution_df = past_execution_df[past_execution_df["beta"] != 0.032]
119-
counterfactual_causal_test_result = causal_test_case.execute_test(estimator=linear_regression_estimator)
123+
counterfactual_causal_test_result = causal_test_case.execute_test()
120124

121125
results_dict["counterfactual"] = {
122126
"ate": counterfactual_causal_test_result.test_value.value,

examples/lr91/example_max_conductances.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -129,23 +129,18 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
129129
causal_test_case = CausalTestCase(
130130
base_test_case=base_test_case,
131131
expected_causal_effect=expected_causal_effect,
132-
control_value=control_val,
133-
treatment_value=treatment_val,
134-
)
135-
136-
# 8. Obtain the minimal adjustment set from the causal DAG
137-
minimal_adjustment_set = causal_dag.identification(base_test_case)
138-
linear_regression_estimator = LinearRegressionEstimator(
139-
treatment_var.name,
140-
treatment_val,
141-
control_val,
142-
minimal_adjustment_set,
143-
"APD90",
144-
df=pd.read_csv(observational_data_path),
132+
estimator=LinearRegressionEstimator(
133+
treatment=treatment_var.name,
134+
treatment_value=treatment_val,
135+
control_value=control_val,
136+
adjustment_set=causal_dag.identification(base_test_case),
137+
outcome="APD90",
138+
df=pd.read_csv(observational_data_path),
139+
),
145140
)
146141

147142
# 9. Run the causal test and print results
148-
causal_test_result = causal_test_case.execute_test(linear_regression_estimator)
143+
causal_test_result = causal_test_case.execute_test()
149144
logger.info("%s", causal_test_result)
150145
return causal_test_result.test_value.value, causal_test_result.confidence_intervals
151146

0 commit comments

Comments
 (0)