|
21 | 21 | "id": "91d1fcf3",
|
22 | 22 | "metadata": {},
|
23 | 23 | "source": [
|
24 |
| - "In This notebook, we show to model and fit a time series model starting from a generative graph. In particular, we explain how to use {func}`~pytensor.scan` to loop efficiently inside a PyMC model.\n", |
| 24 | + "In This notebook, we show to model and fit a time series model starting from a generative graph. In particular, we explain how to use {func}`~pytensor.scan.basic.scan` to loop efficiently inside a PyMC model.\n", |
25 | 25 | "\n",
|
26 | 26 | "For this example, we consider an autoregressive model AR(2). Recall that an AR(2) model is defined as:\n",
|
27 | 27 | "\n",
|
|
79 | 79 | "source": [
|
80 | 80 | "## Define AR(2) Process\n",
|
81 | 81 | "\n",
|
82 |
| - "We start by encoding the generative graph of the AR(2) model as a function `ar_dist`. The strategy is to pass this function as a custom distribution via {class}`~pm.CustomDist` inside a PyMC model. \n", |
| 82 | + "We start by encoding the generative graph of the AR(2) model as a function `ar_dist`. The strategy is to pass this function as a custom distribution via {class}`~pymc.CustomDist` inside a PyMC model. \n", |
83 | 83 | "\n",
|
84 |
| - "We need to specify the initial state (`ar_init`), the autoregressive coefficients (`rho`), and the standard deviation of the noise (`sigma`). Given such parameters, we can define the generative graph of the AR(2) model using the {func}`~pytensor.scan` operation." |
| 84 | + "We need to specify the initial state (`ar_init`), the autoregressive coefficients (`rho`), and the standard deviation of the noise (`sigma`). Given such parameters, we can define the generative graph of the AR(2) model using the {func}`~pytensor.scan.basic.scan` operation." |
85 | 85 | ]
|
86 | 86 | },
|
87 | 87 | {
|
|
173 | 173 | "<text text-anchor=\"middle\" x=\"310\" y=\"-153.7\" font-family=\"Times,serif\" font-size=\"14.00\">CustomDist_ar_dist</text>\n",
|
174 | 174 | "</g>\n",
|
175 | 175 | "<!-- rho->ar_dist -->\n",
|
176 |
| - "<g id=\"edge3\" class=\"edge\">\n", |
| 176 | + "<g id=\"edge4\" class=\"edge\">\n", |
177 | 177 | "<title>rho->ar_dist</title>\n",
|
178 | 178 | "<path fill=\"none\" stroke=\"black\" d=\"M189.45,-267.11C208.64,-252.64 233.46,-233.94 255.57,-217.27\"/>\n",
|
179 | 179 | "<polygon fill=\"black\" stroke=\"black\" points=\"257.45,-220.24 263.33,-211.42 253.24,-214.65 257.45,-220.24\"/>\n",
|
180 | 180 | "</g>\n",
|
181 |
| - "<!-- ar_init -->\n", |
182 |
| - "<g id=\"node2\" class=\"node\">\n", |
183 |
| - "<title>ar_init</title>\n", |
184 |
| - "<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"57\" cy=\"-292.57\" rx=\"41.01\" ry=\"40.66\"/>\n", |
185 |
| - "<text text-anchor=\"middle\" x=\"57\" y=\"-304.02\" font-family=\"Times,serif\" font-size=\"14.00\">ar_init</text>\n", |
186 |
| - "<text text-anchor=\"middle\" x=\"57\" y=\"-287.52\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", |
187 |
| - "<text text-anchor=\"middle\" x=\"57\" y=\"-271.02\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", |
188 |
| - "</g>\n", |
189 | 181 | "<!-- ar_init_obs -->\n",
|
190 |
| - "<g id=\"node3\" class=\"node\">\n", |
| 182 | + "<g id=\"node2\" class=\"node\">\n", |
191 | 183 | "<title>ar_init_obs</title>\n",
|
192 | 184 | "<path fill=\"lightgrey\" stroke=\"black\" d=\"M91.62,-204C91.62,-204 28.38,-204 28.38,-204 22.38,-204 16.38,-198 16.38,-192 16.38,-192 16.38,-158.5 16.38,-158.5 16.38,-152.5 22.38,-146.5 28.38,-146.5 28.38,-146.5 91.62,-146.5 91.62,-146.5 97.62,-146.5 103.62,-152.5 103.62,-158.5 103.62,-158.5 103.62,-192 103.62,-192 103.62,-198 97.62,-204 91.62,-204\"/>\n",
|
193 | 185 | "<text text-anchor=\"middle\" x=\"60\" y=\"-186.7\" font-family=\"Times,serif\" font-size=\"14.00\">ar_init_obs</text>\n",
|
194 | 186 | "<text text-anchor=\"middle\" x=\"60\" y=\"-170.2\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
|
195 | 187 | "<text text-anchor=\"middle\" x=\"60\" y=\"-153.7\" font-family=\"Times,serif\" font-size=\"14.00\">MutableData</text>\n",
|
196 | 188 | "</g>\n",
|
| 189 | + "<!-- ar_init -->\n", |
| 190 | + "<g id=\"node3\" class=\"node\">\n", |
| 191 | + "<title>ar_init</title>\n", |
| 192 | + "<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"57\" cy=\"-292.57\" rx=\"41.01\" ry=\"40.66\"/>\n", |
| 193 | + "<text text-anchor=\"middle\" x=\"57\" y=\"-304.02\" font-family=\"Times,serif\" font-size=\"14.00\">ar_init</text>\n", |
| 194 | + "<text text-anchor=\"middle\" x=\"57\" y=\"-287.52\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", |
| 195 | + "<text text-anchor=\"middle\" x=\"57\" y=\"-271.02\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", |
| 196 | + "</g>\n", |
197 | 197 | "<!-- ar_init->ar_init_obs -->\n",
|
198 | 198 | "<g id=\"edge1\" class=\"edge\">\n",
|
199 | 199 | "<title>ar_init->ar_init_obs</title>\n",
|
200 | 200 | "<path fill=\"none\" stroke=\"black\" d=\"M58.04,-251.56C58.35,-239.85 58.68,-227.07 58.98,-215.41\"/>\n",
|
201 | 201 | "<polygon fill=\"black\" stroke=\"black\" points=\"62.47,-215.81 59.23,-205.72 55.47,-215.63 62.47,-215.81\"/>\n",
|
202 | 202 | "</g>\n",
|
203 | 203 | "<!-- ar_init->ar_dist -->\n",
|
204 |
| - "<g id=\"edge4\" class=\"edge\">\n", |
| 204 | + "<g id=\"edge5\" class=\"edge\">\n", |
205 | 205 | "<title>ar_init->ar_dist</title>\n",
|
206 | 206 | "<path fill=\"none\" stroke=\"black\" d=\"M87.14,-264.52C93.4,-259.83 100.18,-255.36 107,-251.91 149.34,-230.51 165.52,-240.4 210,-223.91 220.35,-220.07 231.08,-215.56 241.5,-210.87\"/>\n",
|
207 | 207 | "<polygon fill=\"black\" stroke=\"black\" points=\"242.8,-214.13 250.42,-206.77 239.88,-207.77 242.8,-214.13\"/>\n",
|
|
229 | 229 | "<text text-anchor=\"middle\" x=\"310\" y=\"-271.02\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
|
230 | 230 | "</g>\n",
|
231 | 231 | "<!-- sigma->ar_dist -->\n",
|
232 |
| - "<g id=\"edge5\" class=\"edge\">\n", |
| 232 | + "<g id=\"edge3\" class=\"edge\">\n", |
233 | 233 | "<title>sigma->ar_dist</title>\n",
|
234 | 234 | "<path fill=\"none\" stroke=\"black\" d=\"M310,-251.56C310,-243.78 310,-235.52 310,-227.44\"/>\n",
|
235 | 235 | "<polygon fill=\"black\" stroke=\"black\" points=\"313.5,-227.7 310,-217.7 306.5,-227.7 313.5,-227.7\"/>\n",
|
|
258 | 258 | "</svg>\n"
|
259 | 259 | ],
|
260 | 260 | "text/plain": [
|
261 |
| - "<graphviz.graphs.Digraph at 0x127a93450>" |
| 261 | + "<graphviz.graphs.Digraph at 0x1353ab310>" |
262 | 262 | ]
|
263 | 263 | },
|
264 | 264 | "execution_count": 3,
|
|
535 | 535 | "\n",
|
536 | 536 | " <div>\n",
|
537 | 537 | " <progress value='8000' class='' max='8000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
538 |
| - " 100.00% [8000/8000 00:15<00:00 Sampling 4 chains, 0 divergences]\n", |
| 538 | + " 100.00% [8000/8000 00:17<00:00 Sampling 4 chains, 0 divergences]\n", |
539 | 539 | " </div>\n",
|
540 | 540 | " "
|
541 | 541 | ],
|
|
550 | 550 | "name": "stderr",
|
551 | 551 | "output_type": "stream",
|
552 | 552 | "text": [
|
553 |
| - "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 26 seconds.\n" |
| 553 | + "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 28 seconds.\n" |
554 | 554 | ]
|
555 | 555 | }
|
556 | 556 | ],
|
|
716 | 716 | "\n",
|
717 | 717 | " <div>\n",
|
718 | 718 | " <progress value='4000' class='' max='4000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
719 |
| - " 100.00% [4000/4000 00:29<00:00]\n", |
| 719 | + " 100.00% [4000/4000 00:30<00:00]\n", |
720 | 720 | " </div>\n",
|
721 | 721 | " "
|
722 | 722 | ],
|
|
881 | 881 | "\n",
|
882 | 882 | "pytensor: 2.18.6\n",
|
883 | 883 | "\n",
|
| 884 | + "arviz : 0.17.0\n", |
| 885 | + "matplotlib: 3.8.2\n", |
| 886 | + "pymc : 5.10.3\n", |
884 | 887 | "numpy : 1.26.3\n",
|
885 | 888 | "pytensor : 2.18.6\n",
|
886 |
| - "pymc : 5.10.3\n", |
887 |
| - "matplotlib: 3.8.2\n", |
888 |
| - "arviz : 0.17.0\n", |
889 | 889 | "\n",
|
890 | 890 | "Watermark: 2.4.3\n",
|
891 | 891 | "\n"
|
|
0 commit comments