Skip to content

Commit 7d6ec17

Browse files
committed
Preliminary coverage
1 parent 8a4d42f commit 7d6ec17

File tree

5 files changed

+83
-36
lines changed

5 files changed

+83
-36
lines changed

causal_testing/json_front/json_class.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from causal_testing.testing.causal_test_engine import CausalTestEngine
2626
from causal_testing.testing.estimators import Estimator
2727
from causal_testing.testing.base_test_case import BaseTestCase
28+
from causal_testing.testing.causal_test_adequacy import DataAdequacy
2829

2930
logger = logging.getLogger(__name__)
3031

@@ -264,9 +265,28 @@ def _execute_test_case(
264265
causal_test_case, test, test["conditions"] if "conditions" in test else None
265266
)
266267
causal_test_result = causal_test_engine.execute_test(estimation_model, causal_test_case)
267-
268268
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
269269

270+
if "coverage" in test and test["coverage"]:
271+
adequacy = DataAdequacy(causal_test_case, causal_test_engine, estimation_model)
272+
results = adequacy.measure_adequacy_bootstrap(100)
273+
outcomes = [causal_test_case.expected_causal_effect.apply(c) for c in results]
274+
coverage = pd.DataFrame(c.to_dict() for c in results)[["effect_estimate", "ci_low", "ci_high"]]
275+
coverage["pass"] = outcomes
276+
std = coverage.std(numeric_only=True)
277+
self._append_to_file(f"COVERAGE: {coverage['pass'].sum()}", logging.INFO)
278+
# std["pass"] = coverage["pass"].sum()
279+
# print(coverage)
280+
# print(std)
281+
282+
# k_folds = adequacy.measure_adequacy_k_folds()
283+
284+
# import matplotlib.pyplot as plt
285+
#
286+
# plt.hist(coverage["ci_low"], alpha=0.8)
287+
# plt.hist(coverage["ci_high"], alpha=0.8)
288+
# plt.show()
289+
270290
if causal_test_result.ci_low() is not None and causal_test_result.ci_high() is not None:
271291
result_string = (
272292
f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < "

causal_testing/testing/causal_test_adequacy.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,13 @@
55
from causal_testing.specification.causal_specification import CausalSpecification
66
from causal_testing.testing.estimators import Estimator
77
from causal_testing.testing.causal_test_case import CausalTestCase
8+
from causal_testing.testing.causal_test_engine import CausalTestEngine
89
from itertools import combinations
9-
10-
11-
class CausalTestAdequacy:
12-
def __init__(
13-
self,
14-
causal_specification: CausalSpecification,
15-
test_suite: CausalTestSuite,
16-
):
17-
self.causal_dag, self.scenario = (
18-
causal_specification.causal_dag,
19-
causal_specification.scenario,
20-
)
21-
self.test_suite = test_suite
22-
self.estimator = estimator
23-
self.dag_adequacy = DAGAdequacy(causal_specification, test_suite)
10+
from copy import deepcopy
11+
from sklearn.model_selection import KFold
12+
from sklearn.metrics import mean_squared_error as mse
13+
import numpy as np
14+
from sklearn.model_selection import cross_val_score
2415

2516

2617
class DAGAdequacy:
@@ -31,9 +22,37 @@ def __init__(
3122
):
3223
self.causal_dag = causal_specification.causal_dag
3324
self.test_suite = test_suite
25+
26+
def measure_adequacy(self):
    """
    Measure how many of the causal DAG's node pairs are exercised by the test suite.

    Stores four attributes on the instance:

    - ``tested_pairs``: the (treatment variable, outcome variable) pair of every test in the suite.
    - ``pairs_to_test``: every 2-combination of the DAG's nodes.
    - ``untested_edges``: the members of ``pairs_to_test`` that are not in ``tested_pairs``.
    - ``dag_adequacy``: ``len(tested_pairs) / len(pairs_to_test)``.

    :raises ZeroDivisionError: If the DAG has fewer than two nodes, since there are then no pairs to test.
    """
    self.tested_pairs = {
        (t.base_test_case.treatment_variable, t.base_test_case.outcome_variable) for t in self.test_suite
    }
    # NOTE(review): combinations() yields each pair in a single orientation, and yields the graph's node
    # objects. A tested pair only counts if it compares equal to one of those tuples -- confirm that the
    # treatment/outcome variables are comparable with the graph's nodes and that orientation is intended.
    self.pairs_to_test = set(combinations(self.causal_dag.graph.nodes, 2))
    self.untested_edges = self.pairs_to_test.difference(self.tested_pairs)
    self.dag_adequacy = len(self.tested_pairs) / len(self.pairs_to_test)
33+
34+
35+
class DataAdequacy:
    """
    Measures how much a causal test's result depends on the particular data it was estimated from.

    Each measurement works on a deep copy of the estimator, so the estimator passed in is never modified.
    """

    def __init__(self, test_case: CausalTestCase, test_engine: CausalTestEngine, estimator: Estimator):
        """
        :param test_case: The causal test case to re-execute.
        :param test_engine: The engine whose ``execute_test`` is called for each bootstrap sample.
        :param estimator: The estimator whose ``df`` is resampled or split into folds.
        """
        self.test_case = test_case
        self.test_engine = test_engine
        self.estimator = estimator

    def measure_adequacy_bootstrap(self, bootstrap_size: int = 100):
        """
        Re-execute the test case on bootstrap resamples of the estimator's data.

        Each resample has as many rows as the original data and is drawn with replacement. The iteration
        index is used as the random seed, so repeated calls produce the same samples.

        :param bootstrap_size: The number of resamples to draw and tests to execute.
        :return: A list with the result of ``test_engine.execute_test`` for each resample, in order.
        """
        results = []
        for i in range(bootstrap_size):
            estimator = deepcopy(self.estimator)
            estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
            results.append(self.test_engine.execute_test(estimator, self.test_case))
        return results

    def measure_adequacy_k_folds(self, k: int = 10, random_state=0):
        """
        Score the estimator's model on each of ``k`` shuffled folds of the estimator's data.

        The score of a fold is the root mean squared error between the model's predictions for the held-out
        rows and those rows' values of the test case's outcome variable. The mean score is printed.

        :param k: The number of folds.
        :param random_state: The seed used to shuffle the data before splitting it into folds.
        :return: A list with the root mean squared error of each fold, in order.
        """
        outcome = self.test_case.base_test_case.outcome_variable.name
        results = []
        kf = KFold(n_splits=k, shuffle=True, random_state=random_state)
        for train_inx, test_inx in kf.split(self.estimator.df):
            estimator = deepcopy(self.estimator)
            test = estimator.df.iloc[test_inx]
            estimator.df = estimator.df.iloc[train_inx]
            # NOTE(review): the copied estimator's existing model is used as-is. Nothing here refits it on
            # the training fold, so `estimator.df` is assigned but never read -- confirm whether a refit
            # was intended before relying on these scores as out-of-sample error.
            predictions = estimator.model.predict(test)
            results.append(np.sqrt(mse(test[outcome], predictions)))
        print("K-score", np.mean(results))
        return results

causal_testing/testing/causal_test_outcome.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,13 @@ def apply(self, res: CausalTestResult) -> bool:
4141
class NoEffect(CausalTestOutcome):
4242
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
4343

44-
def __init__(self, atol: float = 1e-10):
44+
def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
    """
    :param atol: Arithmetic tolerance. For an ATE, a category passes if its confidence interval contains zero or
                 the absolute value of its effect estimate is less than atol. For a risk ratio, atol is the
                 absolute tolerance used when checking whether the estimate is close to 1.
    :param ctol: Categorical tolerance. The proportion of failing categories that cannot be reached without
                 failing the test: an ATE test passes only if the proportion of categories that fail is
                 strictly less than ctol.
    """
    self.atol = atol
    self.ctol = ctol
4651

4752
def apply(self, res: CausalTestResult) -> bool:
4853
if res.test_value.type == "ate":
@@ -52,14 +57,19 @@ def apply(self, res: CausalTestResult) -> bool:
5257
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
5358
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
5459

55-
if not all(ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)):
56-
print(
57-
"FAILING ON",
58-
[(ci_low, ci_high) for ci_low, ci_high in zip(ci_low, ci_high) if not ci_low < 0 < ci_high],
59-
)
60+
# if not all(ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)):
61+
# print(
62+
# "FAILING ON",
63+
# [(ci_low, ci_high) for ci_low, ci_high in zip(ci_low, ci_high) if not ci_low < 0 < ci_high],
64+
# )
6065

61-
return all(ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)) or all(
62-
abs(v) < self.atol for v in value
66+
return (
67+
sum(
68+
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
69+
for ci_low, ci_high, v in zip(ci_low, ci_high, value)
70+
)
71+
/ len(value)
72+
< self.ctol
6373
)
6474
if res.test_value.type == "risk_ratio":
6575
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=self.atol)

causal_testing/testing/causal_test_result.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,17 @@ def to_dict(self):
7272
"""Return result contents as a dictionary
7373
:return: Dictionary containing contents of causal_test_result
7474
"""
75-
base_dict = {
76-
"treatment": self.estimator.treatment[0],
75+
return {
76+
"treatment": self.estimator.treatment,
7777
"control_value": self.estimator.control_value,
7878
"treatment_value": self.estimator.treatment_value,
79-
"outcome": self.estimator.outcome[0],
79+
"outcome": self.estimator.outcome,
8080
"adjustment_set": self.adjustment_set,
81-
"test_value": self.test_value,
81+
"effect_measure": self.test_value.type,
82+
"effect_estimate": self.test_value.value,
83+
"ci_low": self.ci_low(),
84+
"ci_high": self.ci_high(),
8285
}
83-
if self.confidence_intervals and all(self.confidence_intervals):
84-
base_dict["ci_low"] = min(self.confidence_intervals)
85-
base_dict["ci_high"] = max(self.confidence_intervals)
86-
return base_dict
8786

8887
def ci_low(self):
8988
"""Return the lower bracket of the confidence intervals."""

causal_testing/testing/estimators.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def _run_logistic_regression(self, data) -> RegressionResultsWrapper:
150150
treatment_and_adjustments_cols, columns=[col], drop_first=True
151151
)
152152
model = smf.logit(formula=self.formula, data=data).fit(disp=0)
153+
self.model = model
153154
return model
154155

155156
def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> RegressionResultsWrapper:
@@ -165,7 +166,6 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
165166
)
166167

167168
model = self._run_logistic_regression(data)
168-
self.model = model
169169

170170
x = pd.DataFrame(columns=self.df.columns)
171171
x["Intercept"] = 1 # self.intercept
@@ -371,7 +371,6 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
371371
:return: The average treatment effect and the 95% Wald confidence intervals.
372372
"""
373373
model = self._run_linear_regression()
374-
self.model = model
375374

376375
# Create an empty individual for the control and treated
377376
individuals = pd.DataFrame(1, index=["control", "treated"], columns=model.params.index)
@@ -397,7 +396,6 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
397396
adjustment_config = {}
398397

399398
model = self._run_linear_regression()
400-
self.model = model
401399

402400
x = pd.DataFrame(columns=self.df.columns)
403401
x[self.treatment] = [self.treatment_value, self.control_value]
@@ -447,6 +445,7 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
447445
:return: The model after fitting to data.
448446
"""
449447
model = smf.ols(formula=self.formula, data=self.df).fit()
448+
self.model = model
450449
return model
451450

452451
def _get_confidence_intervals(self, model, treatment):

0 commit comments

Comments
 (0)