1
+ from typing import Dict
2
+
1
3
import arviz as az
2
4
import numpy as np
3
5
import pymc as pm
@@ -9,9 +11,10 @@ class ModelBuilder(pm.Model):
9
11
This is a wrapper around pm.Model to give scikit-learn like API
10
12
"""
11
13
12
- def __init__ (self ):
14
+ def __init__ (self , sample_kwargs : Dict = {} ):
13
15
super ().__init__ ()
14
16
self .idata = None
17
+ self .sample_kwargs = sample_kwargs
15
18
16
19
def build_model (self , X , y , coords ):
17
20
raise NotImplementedError
@@ -26,7 +29,7 @@ def fit(self, X, y, coords):
26
29
"""
27
30
self .build_model (X , y , coords )
28
31
with self .model :
29
- self .idata = pm .sample ()
32
+ self .idata = pm .sample (** self . sample_kwargs )
30
33
self .idata .extend (pm .sample_prior_predictive ())
31
34
self .idata .extend (pm .sample_posterior_predictive (self .idata ))
32
35
return self .idata
@@ -69,7 +72,12 @@ def build_model(self, X, y, coords):
69
72
n_predictors = X .shape [1 ]
70
73
X = pm .MutableData ("X" , X , dims = ["obs_ind" , "coeffs" ])
71
74
y = pm .MutableData ("y" , y [:, 0 ], dims = "obs_ind" )
75
+ # TODO: There we should allow user-specified priors here
72
76
beta = pm .Dirichlet ("beta" , a = np .ones (n_predictors ), dims = "coeffs" )
77
+ # beta = pm.Dirichlet(
78
+ # name="beta", a=(1 / n_predictors) * np.ones(n_predictors),
79
+ # dims="coeffs"
80
+ # )
73
81
sigma = pm .HalfNormal ("sigma" , 1 )
74
82
mu = pm .Deterministic ("mu" , pm .math .dot (X , beta ), dims = "obs_ind" )
75
83
pm .Normal ("y_hat" , mu , sigma , observed = y , dims = "obs_ind" )
0 commit comments