@@ -54,18 +54,18 @@ class PyMCModel(pm.Model):
5454 ... "chains": 2,
5555 ... "draws": 2000,
5656 ... "progressbar": False,
57- ... "random_seed": rng ,
57+ ... "random_seed": 42 ,
5858 ... }
5959 ... )
6060 >>> model.fit(X, y)
6161 Inference data...
62+ >>> model.score(X, y) # doctest: +ELLIPSIS
63+ r2 ...
64+ r2_std ...
65+ dtype: float64
6266 >>> X_new = rng.normal(loc=0, scale=1, size=(20,2))
6367 >>> model.predict(X_new)
6468 Inference data...
65- >>> model.score(X, y)
66- r2 0.390344
67- r2_std 0.081135
68- dtype: float64
6969 """
7070
7171 def __init__ (self , sample_kwargs : Optional [Dict [str , Any ]] = None ):
@@ -123,7 +123,6 @@ def predict(self, X):
123123 # Ensure random_seed is used in sample_prior_predictive() and
124124 # sample_posterior_predictive() if provided in sample_kwargs.
125125 random_seed = self .sample_kwargs .get ("random_seed" , None )
126-
127126 self ._data_setter (X )
128127 with self : # sample with new input data
129128 post_pred = pm .sample_posterior_predictive (
@@ -137,18 +136,19 @@ def predict(self, X):
137136 def score (self , X , y ) -> pd .Series :
138137 """Score the Bayesian :math:`R^2` given inputs ``X`` and outputs ``y``.
139138
139+ Note that the score is based on a comparison of the observed data ``y`` and the
140+ model's expected value of the data, `mu`.
141+
140142 .. caution::
141143
142144 The Bayesian :math:`R^2` is not the same as the traditional coefficient of
143145 determination, https://en.wikipedia.org/wiki/Coefficient_of_determination.
144146
145147 """
146- yhat = self .predict (X )
147- yhat = az .extract (
148- yhat , group = "posterior_predictive" , var_names = "y_hat"
149- ).T .values
148+ mu = self .predict (X )
149+ mu = az .extract (mu , group = "posterior_predictive" , var_names = "mu" ).T .values
150150 # Note: First argument must be a 1D array
151- return r2_score (y .flatten (), yhat )
151+ return r2_score (y .flatten (), mu )
152152
153153 def calculate_impact (self , y_true , y_pred ):
154154 pre_data = xr .DataArray (y_true , dims = ["obs_ind" ])
0 commit comments