@@ -5,9 +5,9 @@ jupytext:
55 format_name : myst
66 format_version : 0.13
77kernelspec :
8- display_name : Python 3 (ipykernel)
8+ display_name : pymc-examples
99 language : python
10- name : python3
10+ name : pymc-examples
1111---
1212
1313(spline)=
@@ -43,14 +43,15 @@ import numpy as np
4343import pandas as pd
4444import pymc as pm
4545
46- from patsy import dmatrix
46+ from patsy import build_design_matrices, dmatrix
4747```
4848
4949``` {code-cell} ipython3
5050%matplotlib inline
5151%config InlineBackend.figure_format = "retina"
5252
53- RANDOM_SEED = 8927
53+ seed = sum(map(ord, "splines"))
54+ rng = np.random.default_rng(seed)
5455az.style.use("arviz-darkgrid")
5556```
5657
@@ -84,7 +85,12 @@ If we visualize the data, it is clear that there a lot of annual variation, but
8485
8586``` {code-cell} ipython3
8687blossom_data.plot.scatter(
87- "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Days in bloom"
88+ "year",
89+ "doy",
90+ color="cornflowerblue",
91+ s=10,
92+ title="Cherry Blossom Data",
93+ ylabel="Days in bloom",
8894);
8995```
9096
@@ -106,18 +112,23 @@ The spline will have 15 *knots*, splitting the year into 16 sections (including
106112
107113``` {code-cell} ipython3
108114num_knots = 15
109- knot_list = np.quantile (blossom_data.year, np.linspace(0, 1 , num_knots))
115+ knot_list = np.percentile (blossom_data.year, np.linspace(0, 100 , num_knots + 2))[1:-1]
110116knot_list
111117```
112118
113119Below is a plot of the locations of the knots over the data.
114120
115121``` {code-cell} ipython3
116122blossom_data.plot.scatter(
117- "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
123+ "year",
124+ "doy",
125+ color="cornflowerblue",
126+ s=10,
127+ title="Cherry Blossom Data",
128+ ylabel="Day of Year",
118129)
119130for knot in knot_list:
120- plt.gca().axvline(knot, color="grey", alpha=0.4);
131+ plt.gca().axvline(knot, color="grey", alpha=0.4)
121132```
122133
123134We can use ` patsy ` to create the matrix $B$ that will be the b-spline basis for the regression.
@@ -128,7 +139,7 @@ The degree is set to 3 to create a cubic b-spline.
128139
129140B = dmatrix(
130141 "bs(year, knots=knots, degree=3, include_intercept=True) - 1",
131- {"year": blossom_data.year.values, "knots": knot_list[1:-1] },
142+ {"year": blossom_data.year.values, "knots": knot_list},
132143)
133144B
134145```
@@ -160,9 +171,14 @@ COORDS = {"splines": np.arange(B.shape[1])}
160171with pm.Model(coords=COORDS) as spline_model:
161172 a = pm.Normal("a", 100, 5)
162173 w = pm.Normal("w", mu=0, sigma=3, size=B.shape[1], dims="splines")
163- mu = pm.Deterministic("mu", a + pm.math.dot(np.asarray(B, order="F"), w.T))
174+
175+ mu = pm.Deterministic(
176+ "mu",
177+ a + pm.math.dot(np.asarray(B, order="F"), w.T),
178+ )
164179 sigma = pm.Exponential("sigma", 1)
165- D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy, dims="obs")
180+
181+ D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy)
166182```
167183
168184``` {code-cell} ipython3
@@ -172,7 +188,15 @@ pm.model_to_graphviz(spline_model)
172188``` {code-cell} ipython3
173189with spline_model:
174190 idata = pm.sample_prior_predictive()
175- idata.extend(pm.sample(draws=1000, tune=1000, random_seed=RANDOM_SEED, chains=4))
191+ idata.extend(
192+ pm.sample(
193+ nuts_sampler="nutpie",
194+ draws=1000,
195+ tune=1000,
196+ random_seed=rng,
197+ chains=4,
198+ )
199+ )
176200 pm.sample_posterior_predictive(idata, extend_inferencedata=True)
177201```
178202
@@ -230,7 +254,7 @@ spline_df_merged.plot("year", "value", c="black", lw=2, ax=plt.gca())
230254plt.legend(title="Spline Index", loc="lower center", fontsize=8, ncol=6)
231255
232256for knot in knot_list:
233- plt.gca().axvline(knot, color="grey", alpha=0.4);
257+ plt.gca().axvline(knot, color="grey", alpha=0.4)
234258```
235259
236260### Model predictions
@@ -267,6 +291,150 @@ plt.fill_between(
267291);
268292```
269293
294+ ## Predicting on new data
295+
296+ Now imagine we got a new data set, with the same range of years as the original data set, and we want to get predictions for this new data set with our fitted model. We can do this with the classic PyMC workflow of ` Data ` containers and ` set_data ` method.
297+
298+ Before we get there though, let's note that we didn't say the new data set contains * new* years, i.e out-of-sample years. And that's on purpose, because splines can't extrapolate beyond the range of the data set used to fit the model -- hence their limitation for time series analysis. On data ranges previously seen though, that's no problem.
299+
300+ That precision out of the way, let's redefine our model, this time adding ` Data ` containers.
301+
302+ ``` {code-cell} ipython3
303+ COORDS = {"obs": blossom_data.index}
304+ ```
305+
306+ ``` {code-cell} ipython3
307+ with pm.Model(coords=COORDS) as spline_model:
308+ year_data = pm.Data("year", blossom_data.year)
309+ doy = pm.Data("doy", blossom_data.doy)
310+
311+ # intercept
312+ a = pm.Normal("a", 100, 5)
313+
314+ # Create spline bases & coefficients
315+ ## Store knots & design matrix for prediction
316+ spline_model.knots = np.percentile(year_data.eval(), np.linspace(0, 100, num_knots + 2))[1:-1]
317+ spline_model.dm = dmatrix(
318+ "bs(x, knots=spline_model.knots, degree=3, include_intercept=False) - 1",
319+ {"x": year_data.eval()},
320+ )
321+ spline_model.add_coords({"spline": np.arange(spline_model.dm.shape[1])})
322+ splines_basis = pm.Data("splines_basis", np.asarray(spline_model.dm), dims=("obs", "spline"))
323+ w = pm.Normal("w", mu=0, sigma=3, dims="spline")
324+
325+ mu = pm.Deterministic(
326+ "mu",
327+ a + pm.math.dot(splines_basis, w),
328+ )
329+ sigma = pm.Exponential("sigma", 1)
330+
331+ D = pm.Normal("D", mu=mu, sigma=sigma, observed=doy)
332+ ```
333+
334+ ``` {code-cell} ipython3
335+ pm.model_to_graphviz(spline_model)
336+ ```
337+
338+ ``` {code-cell} ipython3
339+ with spline_model:
340+ idata = pm.sample(
341+ nuts_sampler="nutpie",
342+ random_seed=rng,
343+ )
344+ idata.extend(pm.sample_posterior_predictive(idata, random_seed=rng))
345+ ```
346+
347+ Now we can swap out the data and update the design matrix with the new data:
348+
349+ ``` {code-cell} ipython3
350+ new_blossom_data = (
351+ blossom_data.sample(50, random_state=rng).sort_values("year").reset_index(drop=True)
352+ )
353+
354+ # update design matrix with new data
355+ year_data_new = new_blossom_data.year.to_numpy()
356+ dm_new = build_design_matrices([spline_model.dm.design_info], {"x": year_data_new})[0]
357+ ```
358+
359+ Use ` set_data ` to update the data in the model:
360+
361+ ``` {code-cell} ipython3
362+ with spline_model:
363+ pm.set_data(
364+ new_data={
365+ "year": year_data_new,
366+ "doy": new_blossom_data.doy.to_numpy(),
367+ "splines_basis": np.asarray(dm_new),
368+ },
369+ coords={
370+ "obs": new_blossom_data.index,
371+ },
372+ )
373+ ```
374+
375+ And all that's left is to sample from the posterior predictive distribution:
376+
377+ ``` {code-cell} ipython3
378+ with spline_model:
379+ preds = pm.sample_posterior_predictive(idata, var_names=["mu"])
380+ ```
381+
382+ Plot the predictions, to check if everything went well:
383+
384+ ``` {code-cell} ipython3
385+ _, axes = plt.subplots(1, 2, figsize=(16, 5), sharex=True, sharey=True)
386+
387+ blossom_data.plot.scatter(
388+ "year",
389+ "doy",
390+ color="cornflowerblue",
391+ s=10,
392+ title="Posterior predictions",
393+ ylabel="Days in bloom",
394+ ax=axes[0],
395+ )
396+ axes[0].vlines(
397+ spline_model.knots,
398+ blossom_data.doy.min(),
399+ blossom_data.doy.max(),
400+ color="grey",
401+ alpha=0.4,
402+ )
403+ axes[0].plot(
404+ blossom_data.year,
405+ idata.posterior["mu"].mean(("chain", "draw")),
406+ color="firebrick",
407+ )
408+ az.plot_hdi(blossom_data.year, idata.posterior["mu"], color="firebrick", ax=axes[0])
409+
410+ new_blossom_data.plot.scatter(
411+ "year",
412+ "doy",
413+ color="cornflowerblue",
414+ s=10,
415+ title="Predictions on new data",
416+ ylabel="Days in bloom",
417+ ax=axes[1],
418+ )
419+ axes[1].vlines(
420+ spline_model.knots,
421+ blossom_data.doy.min(),
422+ blossom_data.doy.max(),
423+ color="grey",
424+ alpha=0.4,
425+ )
426+ axes[1].plot(
427+ new_blossom_data.year,
428+ preds.posterior_predictive.mu.mean(("chain", "draw")),
429+ color="firebrick",
430+ )
431+ az.plot_hdi(new_blossom_data.year, preds.posterior_predictive.mu, color="firebrick", ax=axes[1]);
432+ ```
433+
434+ And... voilà! Granted, this example is not the most realistic one, but we trust you to adapt it to your wildest dreams ;)
435+
436+ +++
437+
270438## References
271439
272440:::{bibliography}
@@ -280,6 +448,7 @@ plt.fill_between(
280448- Created by Joshua Cook
281449- Updated by Tyler James Burch
282450- Updated by Chris Fonnesbeck
451+ - Predictions on new data added by Alex Andorra
283452
284453+++
285454
0 commit comments