Skip to content

Commit 8d6251f

Browse files
committed
update adding more tests
Signed-off-by: Nathaniel <[email protected]>
1 parent 73e6a8d commit 8d6251f

File tree

5 files changed

+798
-788
lines changed

5 files changed

+798
-788
lines changed

causalpy/pymc_models.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -684,9 +684,9 @@ def build_model( # type: ignore
684684
Dictionary of priors for the mus and sigmas of both
685685
regressions. Example: ``priors = {"mus": [0, 0],
686686
"sigmas": [1, 1], "eta": 2, "lkj_sd": 2}``.
687-
:param vs_prior_type: An optional string. Can be "spike_and_slab"
687+
vs_prior_type: An optional string. Can be "spike_and_slab"
688688
or "horseshoe" or "normal".
689-
:param vs_hyperparams: An optional dictionary of priors for the
689+
vs_hyperparams: An optional dictionary of priors for the
690690
variable selection hyperparameters
691691
692692
"""
@@ -705,16 +705,18 @@ def build_model( # type: ignore
705705
# Create coefficient priors
706706
if vs_prior_type:
707707
# Use variable selection priors
708-
vs_prior_treatment = VariableSelectionPrior(
708+
self.vs_prior_treatment = VariableSelectionPrior(
709+
vs_prior_type, vs_hyperparams
710+
)
711+
self.vs_prior_outcome = VariableSelectionPrior(
709712
vs_prior_type, vs_hyperparams
710713
)
711-
vs_prior_outcome = VariableSelectionPrior(vs_prior_type, vs_hyperparams)
712714

713-
beta_t = vs_prior_treatment.create_prior(
715+
beta_t = self.vs_prior_treatment.create_prior(
714716
name="beta_t", n_params=Z.shape[1], dims="instruments", X=Z
715717
)
716718

717-
beta_z = vs_prior_outcome.create_prior(
719+
beta_z = self.vs_prior_outcome.create_prior(
718720
name="beta_z", n_params=X.shape[1], dims="covariates", X=X
719721
)
720722
else:
@@ -733,7 +735,7 @@ def build_model( # type: ignore
733735
)
734736

735737
sd_dist = pm.Exponential.dist(priors["lkj_sd"], shape=2)
736-
chol, corr, sigmas = pm.LKJCholeskyCov(
738+
chol, _, _ = pm.LKJCholeskyCov(
737739
name="chol_cov",
738740
eta=priors["eta"],
739741
n=2,

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,45 @@ def test_iv_reg_vs_prior(mock_pymc_sample):
706706
result.get_plot_data()
707707
assert "gamma_beta_t" in result.model.named_vars
708708
assert "pi_beta_t" in result.model.named_vars
709+
summary = result.model.vs_prior_outcome.get_inclusion_probabilities(
710+
result.idata, "beta_z"
711+
)
712+
assert isinstance(summary, pd.DataFrame)
713+
714+
715+
@pytest.mark.integration
716+
def test_iv_reg_vs_prior_hs(mock_pymc_sample):
717+
df = cp.load_data("risk")
718+
instruments_formula = "risk ~ 1 + logmort0"
719+
formula = "loggdp ~ 1 + risk"
720+
instruments_data = df[["risk", "logmort0"]]
721+
data = df[["loggdp", "risk"]]
722+
723+
result = cp.InstrumentalVariable(
724+
instruments_data=instruments_data,
725+
data=data,
726+
instruments_formula=instruments_formula,
727+
formula=formula,
728+
model=cp.pymc_models.InstrumentalVariableRegression(
729+
sample_kwargs=sample_kwargs
730+
),
731+
vs_prior_type="horseshoe",
732+
)
733+
result.model.sample_predictive_distribution(ppc_sampler="pymc")
734+
assert isinstance(df, pd.DataFrame)
735+
assert isinstance(data, pd.DataFrame)
736+
assert isinstance(instruments_data, pd.DataFrame)
737+
assert isinstance(result, cp.InstrumentalVariable)
738+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
739+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
740+
with pytest.raises(NotImplementedError):
741+
result.get_plot_data()
742+
assert "tau_beta_t" in result.model.named_vars
743+
assert "tau_beta_z" in result.model.named_vars
744+
summary = result.model.vs_prior_outcome.get_shrinkage_factors(
745+
result.idata, "beta_z"
746+
)
747+
assert isinstance(summary, pd.DataFrame)
709748

710749

711750
@pytest.mark.integration

causalpy/variable_selection_priors.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from typing import Any, Dict, Optional, Union
2424

2525
import numpy as np
26+
import pandas as pd
2627
import pymc as pm
2728
import pytensor.tensor as pt
2829
from pymc_extras.prior import Prior
@@ -65,9 +66,10 @@ class SpikeAndSlabPrior:
6566
Creates a mixture prior with a point mass at zero (spike) and a diffuse
6667
normal distribution (slab), implemented as:
6768
68-
β_j = γ_j × β_j^raw
69-
70-
where γ_j ∈ [0,1] is a relaxed indicator and β_j^raw ~ N(0, σ_slab²).
69+
.. math::
70+
\beta_{j} = \gamma_{j} \cdot \beta_{j}^{\text{raw}} \\
71+
\beta_{j}^{\text{raw}} \sim \mathcal{N}(0, \sigma_{\text{slab}}^{2}), \qquad
72+
\gamma_{j} \in [0,1].
7173
7274
Parameters
7375
----------
@@ -145,9 +147,9 @@ class HorseshoePrior:
145147
Provides continuous shrinkage with heavy tails, allowing strong signals
146148
to escape shrinkage while weak signals are dampened:
147149
148-
β_j = τ · λ̃_j · β_j^raw
149-
150-
where λ̃_j = √(c²λ_j² / (c² + τ²λ_j²)) is the regularized local shrinkage.
150+
.. math::
151+
\beta_{j} & = \tau \cdot \tilde{\lambda}_{j} \cdot \beta_{j}^{\text{raw}} \\
152+
\tilde{\lambda}_{j} & = \sqrt{ \dfrac{c^{2}\lambda_{j}^{2}}{c^{2} + \tau^{2}\lambda_{j}^{2}} }
151153
152154
Parameters
153155
----------
@@ -423,7 +425,7 @@ def create_prior(
423425

424426
def get_inclusion_probabilities(
425427
self, idata, param_name: str, threshold: float = 0.5
426-
) -> Dict[str, np.ndarray]:
428+
) -> pd.DataFrame:
427429
"""
428430
Extract variable inclusion probabilities from fitted model.
429431
@@ -472,17 +474,24 @@ def get_inclusion_probabilities(
472474
gamma = az.extract(idata.posterior[gamma_name])
473475

474476
# Compute inclusion probabilities
475-
probabilities = (gamma > threshold).mean(dim="sample").values
476-
gamma_mean = gamma.mean(dim="sample").values
477+
probabilities = (gamma > threshold).mean(dim="sample").to_array()
478+
gamma_mean = gamma.mean(dim="sample").to_array()
477479
selected = probabilities > threshold
478480

479-
return {
481+
summary = {
480482
"probabilities": probabilities,
481483
"selected": selected,
482484
"gamma_mean": gamma_mean,
483485
}
486+
probs = summary["probabilities"].T
487+
df = pd.DataFrame(index=list(range(len(probs))))
488+
489+
df["prob"] = probs
490+
df["selected"] = summary["selected"].T
491+
df["gamma_mean"] = summary["gamma_mean"].T
492+
return df
484493

485-
def get_shrinkage_factors(self, idata, param_name: str) -> Dict[str, np.ndarray]:
494+
def get_shrinkage_factors(self, idata, param_name: str) -> pd.DataFrame:
486495
"""
487496
Extract shrinkage factors from horseshoe prior.
488497
@@ -524,17 +533,26 @@ def get_shrinkage_factors(self, idata, param_name: str) -> Dict[str, np.ndarray]
524533
raise ValueError(f"Could not find '{lambda_tilde_name}' in posterior")
525534

526535
# Extract components
527-
tau = az.extract(idata.posterior[tau_name])
528-
lambda_tilde = az.extract(idata.posterior[lambda_tilde_name])
536+
tau = az.extract(idata.posterior[tau_name]).to_array()
537+
lambda_tilde = az.extract(idata.posterior[lambda_tilde_name]).to_array()
529538

530-
# Compute shrinkage factors
531-
shrinkage_factors = (tau * lambda_tilde).mean(dim="sample").values
539+
shrinkage_factor = np.array(
540+
[tau[0, i] * lambda_tilde[0, :, :] for i in range(len(tau))]
541+
)
542+
shrinkage_factor = shrinkage_factor.mean(axis=2)
532543

533-
return {
534-
"shrinkage_factors": shrinkage_factors,
535-
"tau": tau.mean().values,
536-
"lambda_tilde": lambda_tilde.mean(dim="sample").values,
544+
summary = {
545+
"shrinkage_factors": shrinkage_factor,
546+
"tau": tau.mean(),
547+
"lambda_tilde": lambda_tilde.mean(dim=("sample")),
537548
}
549+
probs = summary["shrinkage_factors"].T
550+
df = pd.DataFrame(index=list(range(len(probs))))
551+
df["shrinkage_factor"] = probs
552+
553+
df["lambda_tilde"] = summary["lambda_tilde"].T
554+
df["tau"] = np.mean(tau).item()
555+
return df
538556

539557

540558
def create_variable_selection_prior(

docs/source/_static/interrogate_badge.svg

Lines changed: 4 additions & 4 deletions
Loading

0 commit comments

Comments
 (0)