Skip to content

Commit 8edf6a0

Browse files
committed
fix failing doctest, use posterior expectation for r2 score
1 parent 142694f commit 8edf6a0

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

causalpy/pymc_models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class PyMCModel(pm.Model):
5454
... "chains": 2,
5555
... "draws": 2000,
5656
... "progressbar": False,
57-
... "random_seed": rng,
57+
... "random_seed": 42,
5858
... }
5959
... )
6060
>>> model.fit(X, y)
@@ -63,8 +63,8 @@ class PyMCModel(pm.Model):
6363
>>> model.predict(X_new)
6464
Inference data...
6565
>>> model.score(X, y)
66-
r2 0.390344
67-
r2_std 0.081135
66+
r2 0.19157
67+
r2_std 0.11238
6868
dtype: float64
6969
"""
7070

@@ -123,7 +123,6 @@ def predict(self, X):
123123
# Ensure random_seed is used in sample_prior_predictive() and
124124
# sample_posterior_predictive() if provided in sample_kwargs.
125125
random_seed = self.sample_kwargs.get("random_seed", None)
126-
127126
self._data_setter(X)
128127
with self: # sample with new input data
129128
post_pred = pm.sample_posterior_predictive(
@@ -137,18 +136,19 @@ def predict(self, X):
137136
def score(self, X, y) -> pd.Series:
138137
"""Score the Bayesian :math:`R^2` given inputs ``X`` and outputs ``y``.
139138
139+
Note that the score is based on a comparison of the observed data ``y`` and the
140+
model's expected value of the data, `mu`.
141+
140142
.. caution::
141143
142144
The Bayesian :math:`R^2` is not the same as the traditional coefficient of
143145
determination, https://en.wikipedia.org/wiki/Coefficient_of_determination.
144146
145147
"""
146-
yhat = self.predict(X)
147-
yhat = az.extract(
148-
yhat, group="posterior_predictive", var_names="y_hat"
149-
).T.values
148+
mu = self.predict(X)
149+
mu = az.extract(mu, group="posterior_predictive", var_names="mu").T.values
150150
# Note: First argument must be a 1D array
151-
return r2_score(y.flatten(), yhat)
151+
return r2_score(y.flatten(), mu)
152152

153153
def calculate_impact(self, y_true, y_pred):
154154
pre_data = xr.DataArray(y_true, dims=["obs_ind"])

0 commit comments

Comments
 (0)