Skip to content

Commit db1e81d

Browse files
committed
adding demo notebook
Signed-off-by: Nathaniel <[email protected]>
1 parent 2a01933 commit db1e81d

File tree

3 files changed

+1397
-2
lines changed

3 files changed

+1397
-2
lines changed

causalpy/experiments/instrumental_variable.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,16 @@ class InstrumentalVariable(BaseExperiment):
9191
... formula=formula,
9292
... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
9393
... )
94+
>>> # With variable selection
95+
>>> iv = cp.InstrumentalVariable(
96+
... instruments_data=instruments_data,
97+
... data=data,
98+
... instruments_formula=instruments_formula,
99+
... formula=formula,
100+
... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
101+
... vs_prior_type="spike_and_slab",
102+
... vs_hyperparams={"slab_sigma": 5.0},
103+
... )
94104
"""
95105

96106
supports_ols = False
@@ -115,7 +125,7 @@ def __init__(
115125
self.formula = formula
116126
self.instruments_formula = instruments_formula
117127
self.model = model
118-
self.vs_prior_type = (vs_prior_type,)
128+
self.vs_prior_type = vs_prior_type
119129
self.vs_hyperparams = vs_hyperparams or {}
120130
self.input_validation()
121131

@@ -142,7 +152,7 @@ def __init__(
142152
if priors is None:
143153
priors = {
144154
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
145-
"sigmas": [1, 1],
155+
"sigmas": [10, 10],
146156
"eta": 2,
147157
"lkj_sd": 1,
148158
}

causalpy/pymc_models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Custom PyMC models for causal inference"""
1515

16+
import warnings
1617
from typing import Any, Dict
1718

1819
import arviz as az
@@ -694,6 +695,13 @@ def build_model( # type: ignore
694695
with self:
695696
self.add_coords(coords)
696697

698+
if vs_prior_type and ("mus" in priors or "sigmas" in priors):
699+
warnings.warn(
700+
"Variable selection priors specified. "
701+
"The 'mus' and 'sigmas' in the priors dict will be ignored "
702+
"for beta coefficients. Only 'eta' and 'lkj_sd' will be used."
703+
)
704+
697705
# Create coefficient priors
698706
if vs_prior_type:
699707
# Use variable selection priors

0 commit comments

Comments
 (0)