Skip to content

Commit d3ea6ee

Browse files
committed
add tests base class
1 parent bca7a25 commit d3ea6ee

File tree

4 files changed

+100
-12
lines changed

4 files changed

+100
-12
lines changed

causalpy/pymc_models.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from typing import Dict
2-
1+
from typing import Any, Dict, Optional
2+
import pandas as pd
33
import arviz as az
44
import numpy as np
55
import pymc as pm
@@ -11,19 +11,33 @@ class ModelBuilder(pm.Model):
1111
This is a wrapper around pm.Model to give scikit-learn like API
1212
"""
1313

14-
def __init__(self, sample_kwargs: Dict = {}):
14+
def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
1515
super().__init__()
1616
self.idata = None
17-
self.sample_kwargs = sample_kwargs
18-
19-
def build_model(self, X, y, coords):
20-
raise NotImplementedError
17+
self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {}
18+
19+
def build_model(self, X, y, coords) -> None:
20+
"""Build the model.
21+
22+
Example
23+
-------
24+
>>> class CausalPyModel(ModelBuilder):
25+
>>> def build_model(self, X, y):
26+
>>> with self:
27+
>>> X_ = pm.MutableData(name="X", value=X)
28+
>>> y_ = pm.MutableData(name="y", value=y)
29+
>>> beta = pm.Normal("beta", mu=0, sigma=1, shape=X_.shape[1])
30+
>>> sigma = pm.HalfNormal("sigma", sigma=1)
31+
>>> mu = pm.Deterministic("mu", pm.math.dot(X_, beta))
32+
>>> pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y_)
33+
"""
34+
raise NotImplementedError("This method must be implemented by a subclass")
2135

22-
def _data_setter(self, X):
36+
def _data_setter(self, X) -> None:
2337
with self.model:
2438
pm.set_data({"X": X})
2539

26-
def fit(self, X, y, coords):
40+
def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
2741
"""Draw samples from posterior, prior predictive, and posterior predictive
2842
distributions.
2943
"""
@@ -43,7 +57,7 @@ def predict(self, X):
4357
)
4458
return post_pred
4559

46-
def score(self, X, y):
60+
def score(self, X, y) -> pd.Series:
4761
"""Score the Bayesian :math:`R^2` given inputs ``X`` and outputs ``y``.
4862
4963
.. caution::

causalpy/tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import numpy as np
2+
import pytest
3+
4+
5+
@pytest.fixture(scope="session")
6+
def rng() -> np.random.Generator:
7+
seed: int = sum(map(ord, "causalpy"))
8+
return np.random.default_rng(seed=seed)

causalpy/tests/test_dummy.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

causalpy/tests/test_pymc_models.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import numpy as np
2+
import pytest
3+
import pymc as pm
4+
from causalpy.pymc_models import ModelBuilder
5+
import arviz as az
6+
import pandas as pd
7+
8+
9+
class MyToyModel(ModelBuilder):
10+
def build_model(self, X, y, coords):
11+
with self:
12+
X_ = pm.MutableData(name="X", value=X)
13+
y_ = pm.MutableData(name="y", value=y)
14+
beta = pm.Normal("beta", mu=0, sigma=1, shape=X_.shape[1])
15+
sigma = pm.HalfNormal("sigma", sigma=1)
16+
mu = pm.Deterministic("mu", pm.math.dot(X_, beta))
17+
pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y_)
18+
19+
20+
class TestModelBuilder:
21+
def test_init(self):
22+
mb = ModelBuilder()
23+
assert mb.idata is None
24+
assert mb.sample_kwargs == {}
25+
26+
@pytest.mark.parametrize(
27+
argnames="coords", argvalues=[{"a": 1}, None], ids=["coords-dict", "coord-None"]
28+
)
29+
@pytest.mark.parametrize(
30+
argnames="y", argvalues=[np.ones(3), None], ids=["y-array", "y-None"]
31+
)
32+
@pytest.mark.parametrize(
33+
argnames="X", argvalues=[np.ones(2), None], ids=["X-array", "X-None"]
34+
)
35+
def test_model_builder(self, X, y, coords) -> None:
36+
with pytest.raises(
37+
NotImplementedError, match="This method must be implemented by a subclass"
38+
):
39+
ModelBuilder().build_model(X=X, y=y, coords=coords)
40+
41+
def test_fit_build_not_implemented(self):
42+
with pytest.raises(
43+
NotImplementedError, match="This method must be implemented by a subclass"
44+
):
45+
ModelBuilder().fit(X=np.ones(2), y=np.ones(3), coords={"a": 1})
46+
47+
@pytest.mark.parametrize(
48+
argnames="coords",
49+
argvalues=[None, {"a": 1}],
50+
ids=["None-coords", "dict-coords"],
51+
)
52+
def test_fit_predict(self, coords, rng) -> None:
53+
X = rng.normal(loc=0, scale=1, size=(20, 2))
54+
y = rng.normal(loc=0, scale=1, size=(20,))
55+
model = MyToyModel(sample_kwargs={"chains": 2, "draws": 2})
56+
model.fit(X, y, coords=coords)
57+
predictions = model.predict(X=X)
58+
score = model.score(X=X, y=y)
59+
assert isinstance(model.idata, az.InferenceData)
60+
assert az.extract(data=model.idata, var_names=["beta"]).shape == (2, 2 * 2)
61+
assert az.extract(data=model.idata, var_names=["sigma"]).shape == (2 * 2,)
62+
assert az.extract(data=model.idata, var_names=["mu"]).shape == (20, 2 * 2)
63+
assert az.extract(
64+
data=model.idata, group="posterior_predictive", var_names=["y_hat"]
65+
).shape == (20, 2 * 2)
66+
assert isinstance(score, pd.Series)
67+
assert score.shape == (2,)
68+
assert isinstance(predictions, az.InferenceData)

0 commit comments

Comments
 (0)