Skip to content

Commit bb7767e

Browse files
authored
Update docs to make Predictive behavior more clear (#1850)
* Update docs to make Predictive behavior more clear * Reorder imports to hopefully fix CI issue * Another attempt to revert autoformatting applied by vscode
1 parent f478772 commit bb7767e

File tree

3 files changed

+22
-0
lines changed

3 files changed

+22
-0
lines changed

notebooks/source/bayesian_hierarchical_linear_regression.ipynb

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,13 @@
427427
"samples_predictive = predictive(random.PRNGKey(0), patient_code, Weeks, None)"
428428
]
429429
},
430+
{
431+
"cell_type": "markdown",
432+
"metadata": {},
433+
"source": [
434+
"Note that for [`Predictive`](http://num.pyro.ai/en/latest/utilities.html#numpyro.infer.util.Predictive) to work as expected, the response variable of the model (in this case, `FVC_obs`) must be set to `None`."
435+
]
436+
},
430437
{
431438
"cell_type": "markdown",
432439
"metadata": {},

notebooks/source/bayesian_regression.ipynb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,6 +1507,14 @@
15071507
"ax.set(xlabel=\"Marriage rate\", ylabel=\"Divorce rate\", title=\"Predictions with 90% CI\");"
15081508
]
15091509
},
1510+
{
1511+
"cell_type": "markdown",
1512+
"metadata": {},
1513+
"source": [
1514+
"Note that for `Predictive` to work as expected, the response variable of the model (in this case, `divorce`) must be set to `None`.\n",
1515+
"In the code above this is done implicitly by not passing a value for `divorce` to the model in the call to `prior_predictive`, which due to the model definition, sets `divorce=None`."
1516+
]
1517+
},
15101518
{
15111519
"cell_type": "markdown",
15121520
"metadata": {

numpyro/infer/util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,9 @@ class Predictive(object):
857857
The interface for the `Predictive` class is experimental, and
858858
might change in the future.
859859
860+
Note that for the predictive distribution to be returned as intended, observed
861+
variables in the model (constraining the likelihood term) must be set to `None` (see Example).
862+
860863
:param model: Python callable containing Pyro primitives.
861864
:param dict posterior_samples: dictionary of samples from the posterior.
862865
:param callable guide: optional guide to get posterior samples of sites not present
@@ -908,6 +911,10 @@ def model(X, y=None):
908911
predictive = Predictive(model, num_samples=1000)
909912
y_pred = predictive(rng_key, X)["obs"]
910913
914+
Note how above, no value for `y` is passed to `predictive`, resulting in `y`
915+
being set to `None`. Setting the observed variable(s) to `None` when using
916+
`Predictive` is required for the method to function as expected.
917+
911918
If you also have posterior samples, you can sample from the posterior predictive::
912919
913920
predictive = Predictive(model, posterior_samples=posterior_samples)

0 commit comments

Comments
 (0)