Commit d7c0888 ("first iteration"), parent 78cd6b9

2 files changed: +120 −78 lines

examples/time_series/Time_Series_Generative_Graph.ipynb

Lines changed: 92 additions & 62 deletions (large diff not rendered by default)

examples/time_series/Time_Series_Generative_Graph.myst.md

Lines changed: 28 additions & 16 deletions
@@ -10,7 +10,7 @@ kernelspec:
   name: python3
 ---
 
-(arima_garch_1_1)=
+(time_series_generative_graph)=
 # Time Series Models Derived From a Generative Graph
 
 :::{post} March, 2024
@@ -21,7 +21,7 @@ kernelspec:
 
 +++
 
-In This notebook, we show to model and fit a time series model starting from a generative graph. In particular, we explain how to use {class}`~pytensor.scan` to loop efficiently inside a PyMC model.
+In this notebook, we show how to model and fit a time series model starting from a generative graph. In particular, we explain how to use {func}`~pytensor.scan` to loop efficiently inside a PyMC model.
 
 For this example, we consider an autoregressive model AR(2). Recall that an AR(2) model is defined as:
@@ -57,16 +57,16 @@ rng = np.random.default_rng(42)
 
 ## Define AR(2) Process
 
-We start by encoding the generative graph of the AR(2) model as a function `ar_dist`. The strategy is to pass this function as a custom distribution via {class}`pm.CustomDist` inside a PyMC model.
+We start by encoding the generative graph of the AR(2) model as a function `ar_dist`. The strategy is to pass this function as a custom distribution via {class}`~pm.CustomDist` inside a PyMC model.
 
-We need to specify the initial state (`ar_init`), the autoregressive coefficients (`rho`), and the standard deviation of the noise (`sigma`). Given such parameters, we can define the generative graph of the AR(2) model using the {class}`~pytensor.scan` operation.
+We need to specify the initial state (`ar_init`), the autoregressive coefficients (`rho`), and the standard deviation of the noise (`sigma`). Given such parameters, we can define the generative graph of the AR(2) model using the {func}`~pytensor.scan` operation.
 
 ```{code-cell} ipython3
 lags = 2  # Number of lags
 trials = 100  # Time series length
 
 
-def ar_dist(ar_init, rho, sigma):
+def ar_dist(ar_init, rho, sigma, size):
     def ar_step(x_tm2, x_tm1, rho, sigma):
         mu = x_tm1 * rho[0] + x_tm2 * rho[1]
         x = mu + pm.Normal.dist(sigma=sigma)
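The diff truncates `ar_dist` before the `scan` call. As a sanity check outside PyMC, the same recursion can be sketched in plain Python. This is a hypothetical helper, not part of the notebook; it only mirrors the `mu = x_tm1 * rho[0] + x_tm2 * rho[1]` convention from `ar_step`:

```python
import random


def simulate_ar2(rho, sigma, ar_init, trials, seed=42):
    """Simulate x_t = rho[0] * x_{t-1} + rho[1] * x_{t-2} + eps_t, eps_t ~ N(0, sigma^2)."""
    gen = random.Random(seed)
    x = list(ar_init)  # the first `lags` values are the initial state
    while len(x) < trials:
        mu = rho[0] * x[-1] + rho[1] * x[-2]  # same convention as ar_step
        x.append(mu + gen.gauss(0.0, sigma))
    return x


series = simulate_ar2(rho=(0.5, 0.3), sigma=0.5, ar_init=(0.0, 0.0), trials=100)
```

In the notebook this loop is what `pytensor.scan` unrolls symbolically, so that PyMC can differentiate through it.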
@@ -144,7 +144,7 @@ for i, hdi_prob in enumerate((0.94, 0.64), 1):
     )
 ax.plot(prior.prior["ar"].mean(("chain", "draw")), color="C0", label="Mean")
 ax.legend(loc="upper right")
-ax.set_xlabel("trials")
+ax.set_xlabel("time")
 ax.set_title("AR(2) Prior Samples", fontsize=18, fontweight="bold")
 ```
 
@@ -163,28 +163,37 @@ for i, axi in enumerate(ax, start=chosen_draw):
         color="C0" if i == chosen_draw else "black",
     )
     axi.set_title(f"Sample {i}", fontsize=18, fontweight="bold")
-ax[-1].set_xlabel("trials")
+ax[-1].set_xlabel("time")
 ```
 
 ## Posterior
 
+Next, we want to condition the AR(2) model on some observed data so that we can do a parameter recovery analysis.
+
 ```{code-cell} ipython3
+# Pick a random draw from the prior (i.e. a time series)
 prior_draw = prior.prior.isel(chain=0, draw=chosen_draw)
 
+# Set the observed values
 ar_init_obs.set_value(prior_draw["ar"].values[:lags])
 ar_innov_obs.set_value(prior_draw["ar"].values[lags:])
 ar_obs = prior_draw["ar"].to_numpy()
 rho_true = prior_draw["rho"].to_numpy()
 sigma_true = prior_draw["sigma"].to_numpy()
 
+# Output the true values
 print(f"rho_true={np.round(rho_true, 3)}, {sigma_true=:.3f}")
 ```
 
+We now run the MCMC algorithm to sample from the posterior distribution.
+
 ```{code-cell} ipython3
 with model:
     trace = pm.sample(random_seed=rng)
 ```
 
+Let's plot the trace and the posterior distribution of the parameters.
+
 ```{code-cell} ipython3
 axes = az.plot_trace(
     data=trace,
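The notebook recovers `rho` and `sigma` with MCMC. As a rough cross-check of the parameter recovery idea, the lag coefficients of an AR(2) series can also be estimated by conditional least squares; this hand-rolled 2×2 solve is purely illustrative and not part of the notebook:

```python
import random


def fit_ar2(x):
    """Estimate (rho1, rho2) by regressing x_t on (x_{t-1}, x_{t-2})."""
    y, l1, l2 = x[2:], x[1:-1], x[:-2]
    # Normal equations of the 2x2 least-squares problem, solved by Cramer's rule.
    s11 = sum(a * a for a in l1)
    s22 = sum(b * b for b in l2)
    s12 = sum(a * b for a, b in zip(l1, l2))
    sy1 = sum(t * a for t, a in zip(y, l1))
    sy2 = sum(t * b for t, b in zip(y, l2))
    det = s11 * s22 - s12 * s12
    return (sy1 * s22 - sy2 * s12) / det, (s11 * sy2 - s12 * sy1) / det


# Simulate a long stationary AR(2) series with known coefficients, then fit.
gen = random.Random(0)
rho_true = (0.5, 0.3)
x = [0.0, 0.0]
while len(x) < 5000:
    x.append(rho_true[0] * x[-1] + rho_true[1] * x[-2] + gen.gauss(0.0, 0.5))
rho_hat = fit_ar2(x)
```

Unlike the Bayesian approach, this point estimate comes with no posterior uncertainty, which is exactly what the notebook's MCMC workflow adds.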
@@ -206,8 +215,14 @@ axes = az.plot_posterior(
 plt.gcf().suptitle("AR(2) Model Parameters Posterior", fontsize=18, fontweight="bold")
 ```
 
+We see that we have successfully recovered the true parameters of the model.
+
++++
+
 ## Posterior Predictive
 
+Finally, we can use the posterior samples to generate new data from the AR(2) model. We can then compare the generated data with the observed data to check the goodness of fit of the model.
+
 ```{code-cell} ipython3
 with model:
     post_pred = pm.sample_posterior_predictive(trace, random_seed=rng)
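Conceptually, posterior predictive sampling replays the generative graph once per posterior draw. A toy sketch of that idea, using hand-picked parameter tuples as hypothetical stand-ins for actual posterior draws (real draws would come from `trace`):

```python
import random


def simulate_ar2(rho, sigma, ar_init, trials, gen):
    """One forward pass of the AR(2) generative graph."""
    x = list(ar_init)
    while len(x) < trials:
        x.append(rho[0] * x[-1] + rho[1] * x[-2] + gen.gauss(0.0, sigma))
    return x


gen = random.Random(7)
# Hypothetical stand-ins for posterior draws of (rho, sigma).
param_draws = [((0.48, 0.31), 0.52), ((0.52, 0.29), 0.49), ((0.50, 0.30), 0.50)]
post_pred_series = [
    simulate_ar2(rho, sigma, ar_init=(0.0, 0.0), trials=100, gen=gen)
    for rho, sigma in param_draws
]
```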
@@ -232,19 +247,23 @@ for i, hdi_prob in enumerate((0.94, 0.64), 1):
 ax.plot(prior.prior["ar"].mean(("chain", "draw")), color="C0", label="Mean")
 ax.plot(ar_obs, color="black", label="Observed")
 ax.legend(loc="upper right")
-ax.set_xlabel("trials")
+ax.set_xlabel("time")
 ax.set_title("AR(2) Posterior Predictive Samples", fontsize=18, fontweight="bold")
 ```
 
+Overall, the model is capturing the global dynamics of the time series. To get a better sense of the fit, we can plot a subset of the posterior samples and compare them with the observed data.
+
 ```{code-cell} ipython3
 fig, ax = plt.subplots(
     nrows=5, ncols=1, figsize=(12, 12), sharex=True, sharey=True, layout="constrained"
 )
 for i, axi in enumerate(ax):
     axi.plot(post_pred.posterior_predictive["ar"].isel(draw=i, chain=0), color="C0")
+    axi.plot(ar_obs, color="black", label="Observed")
+    axi.legend(loc="upper right")
     axi.set_title(f"Sample {i}")
 
-ax[-1].set_xlabel("trials")
+ax[-1].set_xlabel("time")
 
 fig.suptitle("AR(2) Posterior Predictive Samples", fontsize=18, fontweight="bold", y=1.05)
 ```
@@ -254,13 +273,6 @@ fig.suptitle("AR(2) Posterior Predictive Samples", fontsize=18, fontweight="bold
 
 +++
 
-## References
-:::{bibliography}
-:filter: docname in docnames
-:::
-
-+++
-
 ## Watermark
 
 ```{code-cell} ipython3
