Skip to content

Commit 132a0fe

Browse files
committed
Add coefficient causal effect metric (and reformat code with black)
1 parent f1bf65a commit 132a0fe

File tree

9 files changed

+111
-56
lines changed

9 files changed

+111
-56
lines changed

causal_testing/json_front/json_class.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from causal_testing.testing.causal_test_case import CausalTestCase
2323
from causal_testing.testing.causal_test_engine import CausalTestEngine
2424
from causal_testing.testing.estimators import Estimator
25+
from causal_testing.testing.base_test_case import BaseTestCase
2526

2627
logger = logging.getLogger(__name__)
2728

@@ -110,20 +111,43 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
110111
for test in self.test_plan["tests"]:
111112
if "skip" in test and test["skip"]:
112113
continue
113-
abstract_test = self._create_abstract_test_case(test, mutates, effects)
114114

115-
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
116-
logger.info("Executing test: %s", test["name"])
117-
logger.info(abstract_test)
118-
logger.info([abstract_test.treatment_variable.name, abstract_test.treatment_variable.distribution])
119-
logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
115+
if test["estimate_type"] == "coefficient":
116+
base_test_case = BaseTestCase(
117+
treatment_variable=next(self.modelling_scenario.variables[v] for v in test["mutations"]),
118+
outcome_variable=next(self.modelling_scenario.variables[v] for v in test["expectedEffect"]),
119+
effect=test["effect"],
120+
)
121+
assert len(test["expectedEffect"]) == 1, "Can only have one expected effect."
122+
concrete_tests = [
123+
CausalTestCase(
124+
base_test_case=base_test_case,
125+
expected_causal_effect=next(
126+
effects[effect] for variable, effect in test["expectedEffect"].items()
127+
),
128+
estimate_type="coefficient",
129+
effect_modifier_configuration={
130+
self.modelling_scenario.variables[v] for v in test.get("effect_modifiers", [])
131+
},
132+
)
133+
]
134+
else:
135+
abstract_test = self._create_abstract_test_case(test, mutates, effects)
136+
137+
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
138+
logger.info("Executing test: %s", test["name"])
139+
logger.info(abstract_test)
140+
logger.info([abstract_test.treatment_variable.name, abstract_test.treatment_variable.distribution])
141+
logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
120142
failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
121143
logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])
122144

123145
def _execute_tests(self, concrete_tests, estimators, test, f_flag):
124146
failures = 0
125147
for concrete_test in concrete_tests:
126-
failed = self._execute_test_case(concrete_test, estimators[test["estimator"]], f_flag)
148+
failed = self._execute_test_case(
149+
concrete_test, estimators[test["estimator"]], f_flag, test.get("conditions", [])
150+
)
127151
if failed:
128152
failures += 1
129153
return failures
@@ -152,7 +176,9 @@ def _populate_metas(self):
152176
var.distribution = getattr(scipy.stats, dist)(**params)
153177
logger.info(var.name + f" {dist}({params})")
154178

155-
def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
179+
def _execute_test_case(
180+
self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool, conditions: list[str]
181+
) -> bool:
156182
"""Executes a singular test case, prints the results and returns the test case result
157183
:param causal_test_case: The concrete test case to be executed
158184
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
@@ -162,7 +188,8 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
162188
"""
163189
failed = False
164190

165-
causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator)
191+
print(causal_test_case)
192+
causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator, conditions)
166193
causal_test_result = causal_test_engine.execute_test(
167194
estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
168195
)
@@ -187,15 +214,17 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
187214
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
188215
return failed
189216

190-
def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
217+
def _setup_test(
218+
self, causal_test_case: CausalTestCase, estimator: Estimator, conditions: list[str]
219+
) -> tuple[CausalTestEngine, Estimator]:
191220
"""Create the necessary inputs for a single test case
192221
:param causal_test_case: The concrete test case to be executed
193222
:returns:
194223
- causal_test_engine - Test Engine instance for the test being run
195224
- estimation_model - Estimator instance for the test being run
196225
"""
197226

198-
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data)
227+
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data.query(" & ".join(conditions)))
199228
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
200229

201230
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)

causal_testing/testing/causal_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(
2727
self,
2828
base_test_case: BaseTestCase,
2929
expected_causal_effect: CausalTestOutcome,
30-
control_value: Any,
30+
control_value: Any = None,
3131
treatment_value: Any = None,
3232
estimate_type: str = "ate",
3333
effect_modifier_configuration: dict[Variable:Any] = None,

causal_testing/testing/causal_test_engine.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,15 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
174174
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
175175
confidence_intervals=confidence_intervals,
176176
)
177+
elif estimate_type == "coefficient":
178+
logger.debug("calculating coefficient")
179+
coefficient, confidence_intervals = estimator.estimate_unit_ate()
180+
causal_test_result = CausalTestResult(
181+
estimator=estimator,
182+
test_value=TestValue("coefficient", coefficient),
183+
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
184+
confidence_intervals=confidence_intervals,
185+
)
177186
elif estimate_type == "ate":
178187
logger.debug("calculating ate")
179188
ate, confidence_intervals = estimator.estimate_ate()

causal_testing/testing/causal_test_outcome.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class SomeEffect(CausalTestOutcome):
2626
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
2727

2828
def apply(self, res: CausalTestResult) -> bool:
29-
if res.test_value.type == "ate":
29+
if res.test_value.type == "ate" or res.test_value.type == "coefficient":
3030
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
3131
if res.test_value.type == "risk_ratio":
3232
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)

causal_testing/testing/estimators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,11 @@ def estimate_unit_ate(self) -> float:
343343
:return: The unit average treatment effect and the 95% Wald confidence intervals.
344344
"""
345345
model = self._run_linear_regression()
346+
assert self.treatment in model.params, f"{self.treatment} not in {model.params}"
346347
unit_effect = model.params[[self.treatment]].values[0] # Unit effect is the coefficient of the treatment
347348
[ci_low, ci_high] = self._get_confidence_intervals(model)
348349

349-
return unit_effect * self.treatment_value - unit_effect * self.control_value, [ci_low, ci_high]
350+
return unit_effect, [ci_low, ci_high]
350351

351352
def estimate_ate(self) -> tuple[float, list[float, float], float]:
352353
"""Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused

examples/covasim_/doubling_beta/example_beta.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -48,54 +48,75 @@ def doubling_beta_CATE_on_csv(
4848
:return results_dict: A nested dictionary containing results (ate and confidence intervals)
4949
for association, causation, and counterfactual (if completed).
5050
"""
51-
results_dict = {'association': {},
52-
'causation': {}}
51+
results_dict = {"association": {}, "causation": {}}
5352

5453
# Read in the observational data, perform identification, and setup the causal_test_engine
5554
past_execution_df = pd.read_csv(observational_data_path)
5655
_, causal_test_engine, causal_test_case = engine_setup(observational_data_path)
5756

58-
linear_regression_estimator = LinearRegressionEstimator('beta', 0.032, 0.016,
59-
{'avg_age', 'contacts'}, # We use custom adjustment set
60-
'cum_infections',
61-
df=past_execution_df,
62-
formula="cum_infections ~ beta + np.power(beta, 2) + avg_age + contacts")
57+
linear_regression_estimator = LinearRegressionEstimator(
58+
"beta",
59+
0.032,
60+
0.016,
61+
{"avg_age", "contacts"}, # We use custom adjustment set
62+
"cum_infections",
63+
df=past_execution_df,
64+
formula="cum_infections ~ beta + np.power(beta, 2) + avg_age + contacts",
65+
)
6366

6467
# Add squared terms for beta, since it has a quadratic relationship with cumulative infections
65-
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, 'ate')
68+
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, "ate")
6669

6770
# Repeat for association estimate (no adjustment)
68-
no_adjustment_linear_regression_estimator = LinearRegressionEstimator('beta', 0.032, 0.016,
69-
set(),
70-
'cum_infections',
71-
df=past_execution_df,
72-
formula="cum_infections ~ beta + np.power(beta, 2)")
73-
association_test_result = causal_test_engine.execute_test(no_adjustment_linear_regression_estimator, causal_test_case, 'ate')
71+
no_adjustment_linear_regression_estimator = LinearRegressionEstimator(
72+
"beta",
73+
0.032,
74+
0.016,
75+
set(),
76+
"cum_infections",
77+
df=past_execution_df,
78+
formula="cum_infections ~ beta + np.power(beta, 2)",
79+
)
80+
association_test_result = causal_test_engine.execute_test(
81+
no_adjustment_linear_regression_estimator, causal_test_case, "ate"
82+
)
7483

7584
# Store results for plotting
76-
results_dict['association'] = {'ate': association_test_result.test_value.value,
77-
'cis': association_test_result.confidence_intervals,
78-
'df': past_execution_df}
79-
results_dict['causation'] = {'ate': causal_test_result.test_value.value,
80-
'cis': causal_test_result.confidence_intervals,
81-
'df': past_execution_df}
85+
results_dict["association"] = {
86+
"ate": association_test_result.test_value.value,
87+
"cis": association_test_result.confidence_intervals,
88+
"df": past_execution_df,
89+
}
90+
results_dict["causation"] = {
91+
"ate": causal_test_result.test_value.value,
92+
"cis": causal_test_result.confidence_intervals,
93+
"df": past_execution_df,
94+
}
8295

8396
if verbose:
8497
print(f"Association:\n{association_test_result}")
8598
print(f"Causation:\n{causal_test_result}")
8699

87100
# Repeat causal inference after deleting all rows with treatment value to obtain counterfactual inferences
88101
if simulate_counterfactuals:
89-
counterfactual_past_execution_df = past_execution_df[past_execution_df['beta'] != 0.032]
90-
counterfactual_linear_regression_estimator = LinearRegressionEstimator('beta', 0.032, 0.016,
91-
{'avg_age', 'contacts'},
92-
'cum_infections',
93-
df=counterfactual_past_execution_df,
94-
formula="cum_infections ~ beta + np.power(beta, 2) + avg_age + contacts")
95-
counterfactual_causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, 'ate')
96-
results_dict['counterfactual'] = {'ate': counterfactual_causal_test_result.test_value.value,
97-
'cis': counterfactual_causal_test_result.confidence_intervals,
98-
'df': counterfactual_past_execution_df}
102+
counterfactual_past_execution_df = past_execution_df[past_execution_df["beta"] != 0.032]
103+
counterfactual_linear_regression_estimator = LinearRegressionEstimator(
104+
"beta",
105+
0.032,
106+
0.016,
107+
{"avg_age", "contacts"},
108+
"cum_infections",
109+
df=counterfactual_past_execution_df,
110+
formula="cum_infections ~ beta + np.power(beta, 2) + avg_age + contacts",
111+
)
112+
counterfactual_causal_test_result = causal_test_engine.execute_test(
113+
linear_regression_estimator, causal_test_case, "ate"
114+
)
115+
results_dict["counterfactual"] = {
116+
"ate": counterfactual_causal_test_result.test_value.value,
117+
"cis": counterfactual_causal_test_result.confidence_intervals,
118+
"df": counterfactual_past_execution_df,
119+
}
99120
if verbose:
100121
print(f"Counterfactual:\n{counterfactual_causal_test_result}")
101122

examples/poisson-line-process/example_poisson_process.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,7 @@ def causal_test_intensity_num_shapes(
8686
data_collector = ObservationalDataCollector(scenario, pd.read_csv(observational_data_path))
8787

8888
# 7. Create an instance of the causal test engine
89-
causal_test_engine = CausalTestEngine(
90-
causal_specification, data_collector
91-
)
89+
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
9290

9391
# 8. Obtain the minimal adjustment set for the causal test case from the causal DAG
9492
minimal_adjustment_set = causal_dag.identification(causal_test_case.base_test_case)
@@ -121,13 +119,11 @@ def causal_test_intensity_num_shapes(
121119
outcome=outcome,
122120
df=data,
123121
effect_modifiers=causal_test_case.effect_modifier_configuration,
124-
formula=f"{outcome} ~ {treatment} + {'+'.join(square_terms + inverse_terms + list([e for e in causal_test_case.effect_modifier_configuration]))} -1"
122+
formula=f"{outcome} ~ {treatment} + {'+'.join(square_terms + inverse_terms + list([e for e in causal_test_case.effect_modifier_configuration]))} -1",
125123
)
126124

127125
# 10. Execute the test
128-
causal_test_result = causal_test_engine.execute_test(
129-
estimator, causal_test_case, causal_test_case.estimate_type
130-
)
126+
causal_test_result = causal_test_engine.execute_test(estimator, causal_test_case, causal_test_case.estimate_type)
131127

132128
return causal_test_result
133129

tests/specification_tests/test_metamorphic_relations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ def test_should_cause_metamorphic_relations_correct_spec_one_input(self):
106106
self.data_collector = SingleInputProgramUnderTestEDC(
107107
self.scenario, self.default_control_input_config, self.default_treatment_input_config
108108
)
109-
causal_dag.graph.remove_nodes_from(['X2', 'X3'])
110-
adj_set = list(causal_dag.direct_effect_adjustment_sets(['X1'], ['Z'])[0])
111-
should_cause_MR = ShouldCause('X1', 'Z', adj_set, causal_dag)
109+
causal_dag.graph.remove_nodes_from(["X2", "X3"])
110+
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
111+
should_cause_MR = ShouldCause("X1", "Z", adj_set, causal_dag)
112112
should_cause_MR.generate_follow_up(10, -10.0, 10.0, 1)
113113
test_results = should_cause_MR.execute_tests(self.data_collector)
114114
should_cause_MR.test_oracle(test_results)

tests/specification_tests/test_variable.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ class Var(Variable):
133133
var = Var("v", int)
134134
self.assertEqual(var.typestring(), "Var")
135135

136-
137136
def test_copy(self):
138137
ip = Input("ip", float, norm)
139138
self.assertTrue(ip.copy() is not ip)
@@ -173,7 +172,7 @@ def test_neg(self):
173172
self.assertEqual(str(-self.i1), "-i1")
174173

175174
def test_pow(self):
176-
self.assertEqual(str(self.i1 ** 5), "i1**5")
175+
self.assertEqual(str(self.i1**5), "i1**5")
177176

178177
def test_le(self):
179178
self.assertEqual(str(self.i1 <= 5), "i1 <= 5")

0 commit comments

Comments
 (0)