Skip to content

Commit 185a76d

Browse files
author
Juan Orduz
committed
auto-fix
1 parent 5b88163 commit 185a76d

File tree

3 files changed

+31
-18
lines changed

3 files changed

+31
-18
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ ci:
33

44
repos:
55
- repo: https://github.com/pre-commit/pre-commit-hooks
6-
rev: v4.6.0
6+
rev: v5.0.0
77
hooks:
88
- id: debug-statements
99
- id: check-merge-conflict
@@ -16,7 +16,7 @@ repos:
1616
- id: trailing-whitespace
1717

1818
- repo: https://github.com/astral-sh/ruff-pre-commit
19-
rev: v0.5.6
19+
rev: v0.8.3
2020
hooks:
2121
- id: ruff
2222
args: ["--fix", "--output-format=full"]

notebooks/pytensor_logp.ipynb

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,14 @@
3434
"metadata": {},
3535
"outputs": [],
3636
"source": [
37-
"import pytensor\n",
38-
"import pytensor.tensor as pt\n",
39-
"import pymc as pm\n",
40-
"import numpy as np\n",
41-
"import nutpie\n",
4237
"import arviz\n",
38+
"import matplotlib.pyplot as plt\n",
39+
"import numpy as np\n",
4340
"import pandas as pd\n",
41+
"import pymc as pm\n",
4442
"import seaborn as sns\n",
45-
"import matplotlib.pyplot as plt"
43+
"\n",
44+
"import nutpie"
4645
]
4746
},
4847
{
@@ -404,7 +403,7 @@
404403
" filename=\"radon_model.stan\",\n",
405404
" coords=coords_stan,\n",
406405
" dims=dims_stan,\n",
407-
" cache=False\n",
406+
" cache=False,\n",
408407
")"
409408
]
410409
},
@@ -484,8 +483,8 @@
484483
"metadata": {},
485484
"outputs": [],
486485
"source": [
487-
"import stan\n",
488486
"import nest_asyncio\n",
487+
"import stan\n",
489488
"\n",
490489
"nest_asyncio.apply()"
491490
]
@@ -522,7 +521,7 @@
522521
],
523522
"source": [
524523
"%%time\n",
525-
"with open(\"radon_model.stan\", \"r\") as file:\n",
524+
"with open(\"radon_model.stan\") as file:\n",
526525
" model = stan.build(file.read(), data=data_stan)"
527526
]
528527
},
@@ -748,7 +747,10 @@
748747
}
749748
],
750749
"source": [
751-
"plt.plot((trace_pymc.warmup_sample_stats.n_steps).isel(draw=slice(0, 1000)).cumsum(\"draw\").T, np.log(trace_pymc.warmup_sample_stats.energy.isel(draw=slice(0, 1000)).T));\n",
750+
"plt.plot(\n",
751+
" (trace_pymc.warmup_sample_stats.n_steps).isel(draw=slice(0, 1000)).cumsum(\"draw\").T,\n",
752+
" np.log(trace_pymc.warmup_sample_stats.energy.isel(draw=slice(0, 1000)).T),\n",
753+
")\n",
752754
"plt.xlim(0, 10000)\n",
753755
"plt.ylabel(\"log-energy\")\n",
754756
"plt.xlabel(\"gradient evaluations\");"
@@ -782,7 +784,13 @@
782784
}
783785
],
784786
"source": [
785-
"plt.plot((trace_cmdstan.warmup_sample_stats.n_steps).isel(draw=slice(0, 1000)).cumsum(\"draw\").T, np.log(trace_cmdstan.warmup_sample_stats.energy.isel(draw=slice(0, 1000)).T));\n",
787+
"plt.plot(\n",
788+
" (trace_cmdstan.warmup_sample_stats.n_steps)\n",
789+
" .isel(draw=slice(0, 1000))\n",
790+
" .cumsum(\"draw\")\n",
791+
" .T,\n",
792+
" np.log(trace_cmdstan.warmup_sample_stats.energy.isel(draw=slice(0, 1000)).T),\n",
793+
")\n",
786794
"plt.xlim(0, 10000)\n",
787795
"plt.ylabel(\"log-energy\")\n",
788796
"plt.xlabel(\"gradient evaluations\");"
@@ -1607,7 +1615,12 @@
16071615
}
16081616
],
16091617
"source": [
1610-
"type({name: int(val) if isinstance(val, int) else list(val) for name, val in data_stan.items()}[\"county_idx\"][0])"
1618+
"type(\n",
1619+
" {\n",
1620+
" name: int(val) if isinstance(val, int) else list(val)\n",
1621+
" for name, val in data_stan.items()\n",
1622+
" }[\"county_idx\"][0]\n",
1623+
")"
16111624
]
16121625
},
16131626
{
@@ -1622,13 +1635,13 @@
16221635
" if isinstance(val, int):\n",
16231636
" data_json[name] = int(val)\n",
16241637
" continue\n",
1625-
" \n",
1638+
"\n",
16261639
" if val.dtype == np.int64:\n",
16271640
" data_json[name] = list(int(x) for x in val)\n",
16281641
" continue\n",
1629-
" \n",
1642+
"\n",
16301643
" data_json[name] = list(val)\n",
1631-
" \n",
1644+
"\n",
16321645
"with open(\"radon.json\", \"w\") as file:\n",
16331646
" json.dump(data_json, file)"
16341647
]

python/nutpie/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from nutpie.sample import sample
55

66
__version__: str = _lib.__version__
7-
__all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"]
7+
__all__ = ["__version__", "compile_pymc_model", "compile_stan_model", "sample"]

0 commit comments

Comments
 (0)