Skip to content

Commit 33d92b5

Browse files
authored
Merge pull request #109 from pymc-labs/add_basic_tests
Add tests for Model Builder. Starts #66 on mypy, and #25 on tests
2 parents 152844a + 98306bc commit 33d92b5

File tree

5 files changed

+102
-12
lines changed

5 files changed

+102
-12
lines changed

.isort.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[settings]
2-
known_third_party = arviz,matplotlib,numpy,pandas,patsy,pymc,scipy,seaborn,setuptools,sklearn,statsmodels,xarray
2+
known_third_party = arviz,matplotlib,numpy,pandas,patsy,pymc,pytest,scipy,seaborn,setuptools,sklearn,statsmodels,xarray

causalpy/pymc_models.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import Dict
1+
from typing import Any, Dict, Optional
22

33
import arviz as az
44
import numpy as np
5+
import pandas as pd
56
import pymc as pm
67
from arviz import r2_score
78

@@ -11,19 +12,33 @@ class ModelBuilder(pm.Model):
1112
This is a wrapper around pm.Model to give scikit-learn like API
1213
"""
1314

14-
def __init__(self, sample_kwargs: Dict = {}):
15+
def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
1516
super().__init__()
1617
self.idata = None
17-
self.sample_kwargs = sample_kwargs
18-
19-
def build_model(self, X, y, coords):
20-
raise NotImplementedError
18+
self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {}
19+
20+
def build_model(self, X, y, coords) -> None:
21+
"""Build the model.
22+
23+
Example
24+
-------
25+
>>> class CausalPyModel(ModelBuilder):
26+
>>> def build_model(self, X, y, coords):
27+
>>> with self:
28+
>>> X_ = pm.MutableData(name="X", value=X)
29+
>>> y_ = pm.MutableData(name="y", value=y)
30+
>>> beta = pm.Normal("beta", mu=0, sigma=1, shape=X_.shape[1])
31+
>>> sigma = pm.HalfNormal("sigma", sigma=1)
32+
>>> mu = pm.Deterministic("mu", pm.math.dot(X_, beta))
33+
>>> pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y_)
34+
"""
35+
raise NotImplementedError("This method must be implemented by a subclass")
2136

22-
def _data_setter(self, X):
37+
def _data_setter(self, X) -> None:
2338
with self.model:
2439
pm.set_data({"X": X})
2540

26-
def fit(self, X, y, coords):
41+
def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
2742
"""Draw samples from posterior, prior predictive, and posterior predictive
2843
distributions.
2944
"""
@@ -43,7 +58,7 @@ def predict(self, X):
4358
)
4459
return post_pred
4560

46-
def score(self, X, y):
61+
def score(self, X, y) -> pd.Series:
4762
"""Score the Bayesian :math:`R^2` given inputs ``X`` and outputs ``y``.
4863
4964
.. caution::

causalpy/tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import numpy as np
2+
import pytest
3+
4+
5+
@pytest.fixture(scope="session")
6+
def rng() -> np.random.Generator:
7+
seed: int = sum(map(ord, "causalpy"))
8+
return np.random.default_rng(seed=seed)

causalpy/tests/test_dummy.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

causalpy/tests/test_pymc_models.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import arviz as az
2+
import numpy as np
3+
import pandas as pd
4+
import pymc as pm
5+
import pytest
6+
7+
from causalpy.pymc_models import ModelBuilder
8+
9+
10+
class MyToyModel(ModelBuilder):
11+
def build_model(self, X, y, coords):
12+
with self:
13+
X_ = pm.MutableData(name="X", value=X)
14+
y_ = pm.MutableData(name="y", value=y)
15+
beta = pm.Normal("beta", mu=0, sigma=1, shape=X_.shape[1])
16+
sigma = pm.HalfNormal("sigma", sigma=1)
17+
mu = pm.Deterministic("mu", pm.math.dot(X_, beta))
18+
pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y_)
19+
20+
21+
class TestModelBuilder:
22+
def test_init(self):
23+
mb = ModelBuilder()
24+
assert mb.idata is None
25+
assert mb.sample_kwargs == {}
26+
27+
@pytest.mark.parametrize(
28+
argnames="coords", argvalues=[{"a": 1}, None], ids=["coords-dict", "coord-None"]
29+
)
30+
@pytest.mark.parametrize(
31+
argnames="y", argvalues=[np.ones(3), None], ids=["y-array", "y-None"]
32+
)
33+
@pytest.mark.parametrize(
34+
argnames="X", argvalues=[np.ones(2), None], ids=["X-array", "X-None"]
35+
)
36+
def test_model_builder(self, X, y, coords) -> None:
37+
with pytest.raises(
38+
NotImplementedError, match="This method must be implemented by a subclass"
39+
):
40+
ModelBuilder().build_model(X=X, y=y, coords=coords)
41+
42+
def test_fit_build_not_implemented(self):
43+
with pytest.raises(
44+
NotImplementedError, match="This method must be implemented by a subclass"
45+
):
46+
ModelBuilder().fit(X=np.ones(2), y=np.ones(3), coords={"a": 1})
47+
48+
@pytest.mark.parametrize(
49+
argnames="coords",
50+
argvalues=[None, {"a": 1}],
51+
ids=["None-coords", "dict-coords"],
52+
)
53+
def test_fit_predict(self, coords, rng) -> None:
54+
X = rng.normal(loc=0, scale=1, size=(20, 2))
55+
y = rng.normal(loc=0, scale=1, size=(20,))
56+
model = MyToyModel(sample_kwargs={"chains": 2, "draws": 2})
57+
model.fit(X, y, coords=coords)
58+
predictions = model.predict(X=X)
59+
score = model.score(X=X, y=y)
60+
assert isinstance(model.idata, az.InferenceData)
61+
assert az.extract(data=model.idata, var_names=["beta"]).shape == (2, 2 * 2)
62+
assert az.extract(data=model.idata, var_names=["sigma"]).shape == (2 * 2,)
63+
assert az.extract(data=model.idata, var_names=["mu"]).shape == (20, 2 * 2)
64+
assert az.extract(
65+
data=model.idata, group="posterior_predictive", var_names=["y_hat"]
66+
).shape == (20, 2 * 2)
67+
assert isinstance(score, pd.Series)
68+
assert score.shape == (2,)
69+
assert isinstance(predictions, az.InferenceData)

0 commit comments

Comments
 (0)