Skip to content

Commit 09829f3

Browse files
committed
user can provide sample_kwargs dict to pymc models
1 parent 45da14e commit 09829f3

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

causalpy/pymc_models.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Dict
2+
13
import arviz as az
24
import numpy as np
35
import pymc as pm
@@ -9,9 +11,10 @@ class ModelBuilder(pm.Model):
911
This is a wrapper around pm.Model to give scikit-learn like API
1012
"""
1113

12-
def __init__(self):
14+
def __init__(self, sample_kwargs: Dict = {}):
1315
super().__init__()
1416
self.idata = None
17+
self.sample_kwargs = sample_kwargs
1518

1619
def build_model(self, X, y, coords):
1720
raise NotImplementedError
@@ -26,7 +29,7 @@ def fit(self, X, y, coords):
2629
"""
2730
self.build_model(X, y, coords)
2831
with self.model:
29-
self.idata = pm.sample()
32+
self.idata = pm.sample(**self.sample_kwargs)
3033
self.idata.extend(pm.sample_prior_predictive())
3134
self.idata.extend(pm.sample_posterior_predictive(self.idata))
3235
return self.idata
@@ -69,7 +72,12 @@ def build_model(self, X, y, coords):
6972
n_predictors = X.shape[1]
7073
X = pm.MutableData("X", X, dims=["obs_ind", "coeffs"])
7174
y = pm.MutableData("y", y[:, 0], dims="obs_ind")
75+
# TODO: There we should allow user-specified priors here
7276
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
77+
# beta = pm.Dirichlet(
78+
# name="beta", a=(1 / n_predictors) * np.ones(n_predictors),
79+
# dims="coeffs"
80+
# )
7381
sigma = pm.HalfNormal("sigma", 1)
7482
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
7583
pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")

0 commit comments

Comments
 (0)