|  | 
| 34 | 34 |    "metadata": {}, | 
| 35 | 35 |    "outputs": [], | 
| 36 | 36 |    "source": [ | 
| 37 |  | -    "import pytensor\n", | 
| 38 |  | -    "import pytensor.tensor as pt\n", | 
| 39 |  | -    "import pymc as pm\n", | 
| 40 |  | -    "import numpy as np\n", | 
| 41 |  | -    "import nutpie\n", | 
| 42 | 37 |     "import arviz\n", | 
|  | 38 | +    "import matplotlib.pyplot as plt\n", | 
|  | 39 | +    "import numpy as np\n", | 
| 43 | 40 |     "import pandas as pd\n", | 
|  | 41 | +    "import pymc as pm\n", | 
| 44 | 42 |     "import seaborn as sns\n", | 
| 45 |  | -    "import matplotlib.pyplot as plt" | 
|  | 43 | +    "\n", | 
|  | 44 | +    "import nutpie" | 
| 46 | 45 |    ] | 
| 47 | 46 |   }, | 
| 48 | 47 |   { | 
|  | 
| 404 | 403 |     "    filename=\"radon_model.stan\",\n", | 
| 405 | 404 |     "    coords=coords_stan,\n", | 
| 406 | 405 |     "    dims=dims_stan,\n", | 
| 407 |  | -    "    cache=False\n", | 
|  | 406 | +    "    cache=False,\n", | 
| 408 | 407 |     ")" | 
| 409 | 408 |    ] | 
| 410 | 409 |   }, | 
|  | 
| 484 | 483 |    "metadata": {}, | 
| 485 | 484 |    "outputs": [], | 
| 486 | 485 |    "source": [ | 
| 487 |  | -    "import stan\n", | 
| 488 | 486 |     "import nest_asyncio\n", | 
|  | 487 | +    "import stan\n", | 
| 489 | 488 |     "\n", | 
| 490 | 489 |     "nest_asyncio.apply()" | 
| 491 | 490 |    ] | 
|  | 
| 522 | 521 |    ], | 
| 523 | 522 |    "source": [ | 
| 524 | 523 |     "%%time\n", | 
| 525 |  | -    "with open(\"radon_model.stan\", \"r\") as file:\n", | 
|  | 524 | +    "with open(\"radon_model.stan\") as file:\n", | 
| 526 | 525 |     "    model = stan.build(file.read(), data=data_stan)" | 
| 527 | 526 |    ] | 
| 528 | 527 |   }, | 
|  | 
| 748 | 747 |     } | 
| 749 | 748 |    ], | 
| 750 | 749 |    "source": [ | 
| 751 |  | -    "plt.plot((trace_pymc.warmup_sample_stats.n_steps).isel(draw=slice(0, 1000)).cumsum(\"draw\").T, np.log(trace_pymc.warmup_sample_stats.energy.isel(draw=slice(0, 1000)).T));\n", | 
|  | 750 | +    "plt.plot(\n", | 
|  | 751 | +    "    (trace_pymc.warmup_sample_stats.n_steps).isel(draw=slice(0, 1000)).cumsum(\"draw\").T,\n", | 
|  | 752 | +    "    np.log(trace_pymc.warmup_sample_stats.energy.isel(draw=slice(0, 1000)).T),\n", | 
|  | 753 | +    ")\n", | 
| 752 | 754 |     "plt.xlim(0, 10000)\n", | 
| 753 | 755 |     "plt.ylabel(\"log-energy\")\n", | 
| 754 | 756 |     "plt.xlabel(\"gradient evaluations\");" | 
|  | 
| 782 | 784 |     } | 
| 783 | 785 |    ], | 
| 784 | 786 |    "source": [ | 
| 785 |  | -    "plt.plot((trace_cmdstan.warmup_sample_stats.n_steps).isel(draw=slice(0, 1000)).cumsum(\"draw\").T, np.log(trace_cmdstan.warmup_sample_stats.energy.isel(draw=slice(0, 1000)).T));\n", | 
|  | 787 | +    "plt.plot(\n", | 
|  | 788 | +    "    (trace_cmdstan.warmup_sample_stats.n_steps)\n", | 
|  | 789 | +    "    .isel(draw=slice(0, 1000))\n", | 
|  | 790 | +    "    .cumsum(\"draw\")\n", | 
|  | 791 | +    "    .T,\n", | 
|  | 792 | +    "    np.log(trace_cmdstan.warmup_sample_stats.energy.isel(draw=slice(0, 1000)).T),\n", | 
|  | 793 | +    ")\n", | 
| 786 | 794 |     "plt.xlim(0, 10000)\n", | 
| 787 | 795 |     "plt.ylabel(\"log-energy\")\n", | 
| 788 | 796 |     "plt.xlabel(\"gradient evaluations\");" | 
|  | 
| 1607 | 1615 |     } | 
| 1608 | 1616 |    ], | 
| 1609 | 1617 |    "source": [ | 
| 1610 |  | -    "type({name: int(val) if isinstance(val, int) else list(val) for name, val in data_stan.items()}[\"county_idx\"][0])" | 
|  | 1618 | +    "type(\n", | 
|  | 1619 | +    "    {\n", | 
|  | 1620 | +    "        name: int(val) if isinstance(val, int) else list(val)\n", | 
|  | 1621 | +    "        for name, val in data_stan.items()\n", | 
|  | 1622 | +    "    }[\"county_idx\"][0]\n", | 
|  | 1623 | +    ")" | 
| 1611 | 1624 |    ] | 
| 1612 | 1625 |   }, | 
| 1613 | 1626 |   { | 
|  | 
| 1622 | 1635 |     "    if isinstance(val, int):\n", | 
| 1623 | 1636 |     "        data_json[name] = int(val)\n", | 
| 1624 | 1637 |     "        continue\n", | 
| 1625 |  | -    "    \n", | 
|  | 1638 | +    "\n", | 
| 1626 | 1639 |     "    if val.dtype == np.int64:\n", | 
| 1627 | 1640 |     "        data_json[name] = list(int(x) for x in val)\n", | 
| 1628 | 1641 |     "        continue\n", | 
| 1629 |  | -    "    \n", | 
|  | 1642 | +    "\n", | 
| 1630 | 1643 |     "    data_json[name] = list(val)\n", | 
| 1631 |  | -    "    \n", | 
|  | 1644 | +    "\n", | 
| 1632 | 1645 |     "with open(\"radon.json\", \"w\") as file:\n", | 
| 1633 | 1646 |     "    json.dump(data_json, file)" | 
| 1634 | 1647 |    ] | 
|  | 
0 commit comments