Skip to content

Commit 5dbc15e

Browse files
committed
add test coverage around weighting
Signed-off-by: Nathaniel <[email protected]>
1 parent 6323c11 commit 5dbc15e

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

causalpy/tests/test_pymc_experiments.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
"""
44

55
import causalpy as cp
6+
import arviz as az
7+
import pandas as pd
68

79
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
810

@@ -42,3 +44,34 @@ def test_regression_kink_gradient_change():
4244
cp.pymc_experiments.RegressionKink._eval_gradient_change(-1, -1, -2, 1) == -1.0
4345
)
4446
assert cp.pymc_experiments.RegressionKink._eval_gradient_change(1, 0, -2, 1) == -1.0
47+
48+
49+
def test_inverse_prop_param_recovery():
50+
df = cp.load_data("nhefs")
51+
seed = 42
52+
result = cp.pymc_experiments.InversePropensityWeighting(
53+
df,
54+
formula="trt ~ 1 + age + race",
55+
outcome_variable ="outcome",
56+
weighting_scheme="robust",
57+
model=cp.pymc_models.PropensityScore(
58+
sample_kwargs=sample_kwargs
59+
),
60+
)
61+
assert isinstance(result.idata, az.InferenceData)
62+
ps = result.idata.posterior['p'].mean(dim=('chain', 'draw'))
63+
w1, w2, _, _ = result.make_doubly_robust_adjustment(ps)
64+
assert isinstance(w1, pd.Series)
65+
assert isinstance(w2, pd.Series)
66+
w1, w2, n1, nw = result.make_raw_adjustments(ps)
67+
assert isinstance(w1, pd.Series)
68+
assert isinstance(w2, pd.Series)
69+
w1, w2, n1, n2 = result.make_robust_adjustments(ps)
70+
assert isinstance(w1, pd.Series)
71+
assert isinstance(w2, pd.Series)
72+
w1, w2, n1, n2 = result.make_overlap_adjustments(ps)
73+
assert isinstance(w1, pd.Series)
74+
assert isinstance(w2, pd.Series)
75+
76+
77+

0 commit comments

Comments
 (0)