Skip to content

Commit b8e5751

Browse files
committed
Merge branch 'main' into pre-commit-ci-update-config
2 parents e65049f + e61bfd0 commit b8e5751

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

causalpy/pymc_models.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,18 @@ class PyMCModel(pm.Model):
5454
... "chains": 2,
5555
... "draws": 2000,
5656
... "progressbar": False,
57-
... "random_seed": rng,
57+
... "random_seed": 42,
5858
... }
5959
... )
6060
>>> model.fit(X, y)
6161
Inference data...
62+
>>> model.score(X, y) # doctest: +ELLIPSIS
63+
r2 ...
64+
r2_std ...
65+
dtype: float64
6266
>>> X_new = rng.normal(loc=0, scale=1, size=(20,2))
6367
>>> model.predict(X_new)
6468
Inference data...
65-
>>> model.score(X, y)
66-
r2 0.390344
67-
r2_std 0.081135
68-
dtype: float64
6969
"""
7070

7171
def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
@@ -123,7 +123,6 @@ def predict(self, X):
123123
# Ensure random_seed is used in sample_prior_predictive() and
124124
# sample_posterior_predictive() if provided in sample_kwargs.
125125
random_seed = self.sample_kwargs.get("random_seed", None)
126-
127126
self._data_setter(X)
128127
with self: # sample with new input data
129128
post_pred = pm.sample_posterior_predictive(
@@ -137,18 +136,19 @@ def predict(self, X):
137136
def score(self, X, y) -> pd.Series:
138137
"""Score the Bayesian :math:`R^2` given inputs ``X`` and outputs ``y``.
139138
139+
Note that the score is based on a comparison of the observed data ``y`` and the
140+
model's expected value of the data, `mu`.
141+
140142
.. caution::
141143
142144
The Bayesian :math:`R^2` is not the same as the traditional coefficient of
143145
determination, https://en.wikipedia.org/wiki/Coefficient_of_determination.
144146
145147
"""
146-
yhat = self.predict(X)
147-
yhat = az.extract(
148-
yhat, group="posterior_predictive", var_names="y_hat"
149-
).T.values
148+
mu = self.predict(X)
149+
mu = az.extract(mu, group="posterior_predictive", var_names="mu").T.values
150150
# Note: First argument must be a 1D array
151-
return r2_score(y.flatten(), yhat)
151+
return r2_score(y.flatten(), mu)
152152

153153
def calculate_impact(self, y_true, y_pred):
154154
pre_data = xr.DataArray(y_true, dims=["obs_ind"])

0 commit comments

Comments
 (0)