Skip to content

Commit 73e6a8d

Browse files
committed
adding tests
Signed-off-by: Nathaniel <[email protected]>
1 parent 05dc20d commit 73e6a8d

File tree

3 files changed

+119
-5
lines changed

3 files changed

+119
-5
lines changed

causalpy/experiments/instrumental_variable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ class InstrumentalVariable(BaseExperiment):
5151
If priors are not specified we will substitute MLE estimates for
5252
the beta coefficients. Example: ``priors = {"mus": [0, 0],
5353
"sigmas": [1, 1], "eta": 2, "lkj_sd": 2}``.
54-
:param vs_prior_type : str or None, default=None
54+
vs_prior_type : str or None, default=None
5555
Type of variable selection prior: 'spike_and_slab', 'horseshoe', or None.
5656
If None, uses standard normal priors.
57-
:param vs_hyperparams : dict, optional
57+
vs_hyperparams : dict, optional
5858
Hyperparameters for variable selection priors. Only used if vs_prior_type
5959
is not None.
6060
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2022 - 2025 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pymc as pm
17+
import pytest
18+
19+
from causalpy.variable_selection_priors import (
20+
HorseshoePrior,
21+
SpikeAndSlabPrior,
22+
VariableSelectionPrior,
23+
create_variable_selection_prior,
24+
)
25+
26+
27+
@pytest.fixture
28+
def sample_data():
29+
"""Generate sample design matrix for testing."""
30+
rng = np.random.default_rng(42)
31+
n_obs = 100
32+
n_features = 5
33+
X = rng.normal(size=(n_obs, n_features))
34+
return X
35+
36+
37+
@pytest.fixture
38+
def coords():
39+
"""Generate sample coordinates for PyMC models."""
40+
return {"features": [f"x_{i}" for i in range(5)]}
41+
42+
43+
def test_create_variable_in_model_context(coords):
44+
"""Test that create_variable works in PyMC model context."""
45+
prior = SpikeAndSlabPrior(dims="features")
46+
47+
with pm.Model(coords=coords) as model:
48+
beta = prior.create_variable("beta")
49+
50+
# Check that beta was created
51+
assert "beta" in model.named_vars
52+
assert beta.name == "beta"
53+
54+
# Check that intermediate variables were created
55+
assert "pi_beta" in model.named_vars
56+
assert "beta_raw" in model.named_vars
57+
assert "gamma_beta" in model.named_vars
58+
59+
60+
def test_create_variable_in_model_context_horseshoe(coords):
61+
"""Test that create_variable works in PyMC model context."""
62+
prior = HorseshoePrior(dims="features")
63+
64+
with pm.Model(coords=coords) as model:
65+
beta = prior.create_variable("beta")
66+
67+
# Check that beta was created
68+
assert "beta" in model.named_vars
69+
assert beta.name == "beta"
70+
71+
# Check that intermediate variables were created
72+
assert "tau_beta" in model.named_vars
73+
assert "lambda_beta" in model.named_vars
74+
assert "c2_beta" in model.named_vars
75+
assert "lambda_tilde_beta" in model.named_vars
76+
assert "beta_raw" in model.named_vars
77+
78+
79+
def test_create_prior_spike_and_slab(coords):
80+
"""Test create_prior for spike-and-slab."""
81+
vs_prior = VariableSelectionPrior("spike_and_slab")
82+
83+
with pm.Model(coords=coords) as model:
84+
beta = vs_prior.create_prior(name="beta", n_params=5, dims="features")
85+
86+
assert "beta" in model.named_vars
87+
assert beta.name == "beta"
88+
89+
90+
def test_create_prior_horseshoe(coords, sample_data):
91+
"""Test create_prior for horseshoe."""
92+
vs_prior = VariableSelectionPrior("horseshoe")
93+
94+
with pm.Model(coords=coords) as model:
95+
beta = vs_prior.create_prior(
96+
name="beta", n_params=5, dims="features", X=sample_data
97+
)
98+
99+
assert "beta" in model.named_vars
100+
assert beta.name == "beta"
101+
102+
103+
def test_convenience_function_with_custom_hyperparams(coords):
104+
"""Test convenience function with custom hyperparameters."""
105+
with pm.Model(coords=coords) as model:
106+
_ = create_variable_selection_prior(
107+
prior_type="spike_and_slab",
108+
name="beta",
109+
n_params=5,
110+
dims="features",
111+
hyperparams={"slab_sigma": 5},
112+
)
113+
114+
assert "beta" in model.named_vars

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)