Skip to content

Commit 1f4f695

Browse files
committed
adding vs module
Signed-off-by: Nathaniel <[email protected]>
1 parent 62c27ed commit 1f4f695

File tree

6 files changed

+698
-19
lines changed

6 files changed

+698
-19
lines changed

causalpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,5 @@
4141
"RegressionKink",
4242
"skl_models",
4343
"SyntheticControl",
44+
"variable_selection_priors",
4445
]

causalpy/experiments/instrumental_variable.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ class InstrumentalVariable(BaseExperiment):
5252
"eta": 2,
5353
"lkj_sd": 2,
5454
}
55+
:param vs_prior_type : str or None, default=None
56+
Type of variable selection prior: 'spike_and_slab', 'horseshoe', or None.
57+
If None, uses standard normal priors.
58+
:param vs_hyperparams : dict, optional
59+
Hyperparameters for variable selection priors. Only used if vs_prior_type
60+
is not None.
5561
5662
Example
5763
--------
@@ -99,6 +105,8 @@ def __init__(
99105
formula: str,
100106
model=None,
101107
priors=None,
108+
vs_prior_type=None,
109+
vs_hyperparams=None,
102110
**kwargs,
103111
):
104112
super().__init__(model=model)
@@ -108,6 +116,8 @@ def __init__(
108116
self.formula = formula
109117
self.instruments_formula = instruments_formula
110118
self.model = model
119+
self.vs_prior_type = (vs_prior_type,)
120+
self.vs_hyperparams = vs_hyperparams or {}
111121
self.input_validation()
112122

113123
y, X = dmatrices(formula, self.data)
@@ -139,7 +149,14 @@ def __init__(
139149
}
140150
self.priors = priors
141151
self.model.fit(
142-
X=self.X, Z=self.Z, y=self.y, t=self.t, coords=COORDS, priors=self.priors
152+
X=self.X,
153+
Z=self.Z,
154+
y=self.y,
155+
t=self.t,
156+
coords=COORDS,
157+
priors=self.priors,
158+
vs_prior_type=vs_prior_type,
159+
vs_hyperparams=vs_hyperparams,
143160
)
144161

145162
def input_validation(self):

causalpy/pymc_models.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pymc_extras.prior import Prior
2727

2828
from causalpy.utils import round_num
29+
from causalpy.variable_selection_priors import VariableSelectionPrior
2930

3031

3132
class PyMCModel(pm.Model):
@@ -604,7 +605,9 @@ class InstrumentalVariableRegression(PyMCModel):
604605
Inference data...
605606
"""
606607

607-
def build_model(self, X, Z, y, t, coords, priors):
608+
def build_model(
609+
self, X, Z, y, t, coords, priors, vs_prior_type=None, vs_hyperparams=None
610+
):
608611
"""Specify model with treatment regression and focal regression data and priors
609612
610613
:param X: A pandas dataframe used to predict our outcome y
@@ -618,23 +621,47 @@ def build_model(self, X, Z, y, t, coords, priors):
618621
sigmas of both regressions
619622
:code:`priors = {"mus": [0, 0], "sigmas": [1, 1],
620623
"eta": 2, "lkj_sd": 2}`
624+
:param vs_prior_type: An optional string. Can be "spike_and_slab"
625+
or "horseshoe" or "normal
626+
:param vs_hyperparams: An optional dictionary of priors for the
627+
variable selection hyperparameters
628+
621629
"""
622630

623631
# --- Priors ---
624632
with self:
625633
self.add_coords(coords)
626-
beta_t = pm.Normal(
627-
name="beta_t",
628-
mu=priors["mus"][0],
629-
sigma=priors["sigmas"][0],
630-
dims="instruments",
631-
)
632-
beta_z = pm.Normal(
633-
name="beta_z",
634-
mu=priors["mus"][1],
635-
sigma=priors["sigmas"][1],
636-
dims="covariates",
637-
)
634+
635+
# Create coefficient priors
636+
if vs_prior_type:
637+
# Use variable selection priors
638+
vs_prior_treatment = VariableSelectionPrior(
639+
vs_prior_type, vs_hyperparams
640+
)
641+
vs_prior_outcome = VariableSelectionPrior(vs_prior_type, vs_hyperparams)
642+
643+
beta_t = vs_prior_treatment.create_prior(
644+
name="beta_t", n_params=Z.shape[1], dims="instruments", X=Z
645+
)
646+
647+
beta_z = vs_prior_outcome.create_prior(
648+
name="beta_z", n_params=X.shape[1], dims="covariates", X=X
649+
)
650+
else:
651+
# Use standard normal priors
652+
beta_t = pm.Normal(
653+
name="beta_t",
654+
mu=priors["mus"][0],
655+
sigma=priors["sigmas"][0],
656+
dims="instruments",
657+
)
658+
beta_z = pm.Normal(
659+
name="beta_z",
660+
mu=priors["mus"][1],
661+
sigma=priors["sigmas"][1],
662+
dims="covariates",
663+
)
664+
638665
sd_dist = pm.Exponential.dist(priors["lkj_sd"], shape=2)
639666
chol, corr, sigmas = pm.LKJCholeskyCov(
640667
name="chol_cov",
@@ -689,7 +716,18 @@ def sample_predictive_distribution(self, ppc_sampler="jax"):
689716
)
690717
)
691718

692-
def fit(self, X, Z, y, t, coords, priors, ppc_sampler=None):
719+
def fit(
720+
self,
721+
X,
722+
Z,
723+
y,
724+
t,
725+
coords,
726+
priors,
727+
ppc_sampler=None,
728+
vs_prior_type=None,
729+
vs_hyperparams=None,
730+
):
693731
"""Draw samples from posterior distribution and potentially
694732
from the prior and posterior predictive distributions. The
695733
fit call can take values for the
@@ -703,7 +741,7 @@ def fit(self, X, Z, y, t, coords, priors, ppc_sampler=None):
703741
# sample_posterior_predictive() if provided in sample_kwargs.
704742
# Use JAX for ppc sampling of multivariate likelihood
705743

706-
self.build_model(X, Z, y, t, coords, priors)
744+
self.build_model(X, Z, y, t, coords, priors, vs_prior_type, vs_hyperparams)
707745
with self:
708746
self.idata = pm.sample(**self.sample_kwargs)
709747
self.sample_predictive_distribution(ppc_sampler=ppc_sampler)

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,38 @@ def test_iv_reg(mock_pymc_sample):
676676
result.get_plot_data()
677677

678678

679+
@pytest.mark.integration
680+
def test_iv_reg_vs_prior(mock_pymc_sample):
681+
df = cp.load_data("risk")
682+
instruments_formula = "risk ~ 1 + logmort0"
683+
formula = "loggdp ~ 1 + risk"
684+
instruments_data = df[["risk", "logmort0"]]
685+
data = df[["loggdp", "risk"]]
686+
687+
result = cp.InstrumentalVariable(
688+
instruments_data=instruments_data,
689+
data=data,
690+
instruments_formula=instruments_formula,
691+
formula=formula,
692+
model=cp.pymc_models.InstrumentalVariableRegression(
693+
sample_kwargs=sample_kwargs
694+
),
695+
vs_prior_type="spike_and_slab",
696+
vs_hyperparams={"pi_alpha": 5},
697+
)
698+
result.model.sample_predictive_distribution(ppc_sampler="pymc")
699+
assert isinstance(df, pd.DataFrame)
700+
assert isinstance(data, pd.DataFrame)
701+
assert isinstance(instruments_data, pd.DataFrame)
702+
assert isinstance(result, cp.InstrumentalVariable)
703+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
704+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
705+
with pytest.raises(NotImplementedError):
706+
result.get_plot_data()
707+
assert "gamma_beta_t" in result.model.named_vars
708+
assert "pi_beta_t" in result.model.named_vars
709+
710+
679711
@pytest.mark.integration
680712
def test_inverse_prop(mock_pymc_sample):
681713
"""Test the InversePropensityWeighting class."""

0 commit comments

Comments
 (0)