Skip to content

Commit b64f810

Browse files
authored
Merge pull request #3327 from pymc-devs/merge_kwargs
WIP: Merge nuts_kwargs and step_kwargs into kwargs
2 parents 2b5dd34 + 4a53c7e commit b64f810

14 files changed

+59
-55
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
### Deprecations
1515

16+
- `nuts_kwargs` and `step_kwargs` have been deprecated in favor of using the standard `kwargs` to pass optional step method arguments.
17+
1618
## PyMC3 3.6 (Dec 21 2018)
1719

1820
This will be the last release to support Python 2.

docs/source/notebooks/Diagnosing_biased_Inference_with_Divergences.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,7 @@
10021002
"source": [
10031003
"with Centered_eight:\n",
10041004
" fit_cp85 = pm.sample(5000, chains=2, tune=2000,\n",
1005-
" nuts_kwargs=dict(target_accept=.85))"
1005+
" target_accept=.85)"
10061006
]
10071007
},
10081008
{
@@ -1029,7 +1029,7 @@
10291029
"source": [
10301030
"with Centered_eight:\n",
10311031
" fit_cp90 = pm.sample(5000, chains=2, tune=2000,\n",
1032-
" nuts_kwargs=dict(target_accept=.90))"
1032+
" target_accept=.90)"
10331033
]
10341034
},
10351035
{
@@ -1056,7 +1056,7 @@
10561056
"source": [
10571057
"with Centered_eight:\n",
10581058
" fit_cp95 = pm.sample(5000, chains=2, tune=2000,\n",
1059-
" nuts_kwargs=dict(target_accept=.95))"
1059+
" target_accept=.95)"
10601060
]
10611061
},
10621062
{
@@ -1083,7 +1083,7 @@
10831083
"source": [
10841084
"with Centered_eight:\n",
10851085
" fit_cp99 = pm.sample(5000, chains=2, tune=2000,\n",
1086-
" nuts_kwargs=dict(target_accept=.99))"
1086+
" target_accept=.99)"
10871087
]
10881088
},
10891089
{
@@ -1350,7 +1350,7 @@
13501350
"source": [
13511351
"with NonCentered_eight:\n",
13521352
" fit_ncp80 = pm.sample(5000, chains=2, tune=1000, random_seed=SEED,\n",
1353-
" nuts_kwargs=dict(target_accept=.80))"
1353+
" target_accept=.80)"
13541354
]
13551355
},
13561356
{
@@ -1708,7 +1708,7 @@
17081708
"source": [
17091709
"with NonCentered_eight:\n",
17101710
" fit_ncp90 = pm.sample(5000, chains=2, tune=1000, random_seed=SEED,\n",
1711-
" nuts_kwargs=dict(target_accept=.90))\n",
1711+
" target_accept=.90)\n",
17121712
" \n",
17131713
"# display the total number and percentage of divergent\n",
17141714
"divergent = fit_ncp90['diverging']\n",

docs/source/notebooks/GLM-hierarchical-binominal-model.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@
309309
" theta = pm.Beta('theta', alpha=ab[0], beta=ab[1], shape=N)\n",
310310
"\n",
311311
" p = pm.Binomial('y', p=theta, observed=y, n=n)\n",
312-
" trace = pm.sample(1000, tune=2000, nuts_kwargs={'target_accept': .95})\n",
312+
" trace = pm.sample(1000, tune=2000, target_accept=0.95)\n",
313313
" "
314314
]
315315
},

docs/source/notebooks/GLM-rolling-regression.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@
328328
"source": [
329329
"with model_randomwalk:\n",
330330
" trace_rw = pm.sample(tune=2000, cores=4, samples=200, \n",
331-
" nuts_kwargs=dict(target_accept=.9))"
331+
" target_accept=0.9)"
332332
]
333333
},
334334
{

docs/source/notebooks/GP-MaunaLoa2.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@
260260
],
261261
"source": [
262262
"with model:\n",
263-
" tr = pm.sample(1000, tune=1000, chains=2, cores=1, nuts_kwargs={\"target_accept\":0.95})"
263+
" tr = pm.sample(1000, tune=1000, chains=2, cores=1, target_accept=0.95)"
264264
]
265265
},
266266
{
@@ -595,7 +595,7 @@
595595
],
596596
"source": [
597597
"with model:\n",
598-
" tr = pm.sample(1000, tune=1000, chains=2, cores=1, nuts_kwargs={\"target_accept\":0.95})"
598+
" tr = pm.sample(1000, tune=1000, chains=2, cores=1, target_accept=0.95)"
599599
]
600600
},
601601
{
@@ -1084,7 +1084,7 @@
10841084
],
10851085
"source": [
10861086
"with model:\n",
1087-
" tr = pm.sample(500, chains=2, cores=1, nuts_kwargs={\"target_accept\": 0.95})"
1087+
" tr = pm.sample(500, chains=2, cores=1, target_accept=0.95)"
10881088
]
10891089
},
10901090
{

docs/source/notebooks/PyMC3_tips_and_heuristic.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@
484484
" # Proportion sptial variance\n",
485485
" alpha = pm.Deterministic('alpha', sd_c/(sd_h+sd_c))\n",
486486
"\n",
487-
" trace1 = pm.sample(3e3, cores=2, tune=1000, nuts_kwargs={'max_treedepth': 15})"
487+
" trace1 = pm.sample(3e3, cores=2, tune=1000, max_treedepth=15)"
488488
]
489489
},
490490
{
@@ -702,7 +702,7 @@
702702
" # Proportion sptial variance\n",
703703
" alpha = pm.Deterministic('alpha', sd_c/(sd_h+sd_c))\n",
704704
"\n",
705-
" trace2 = pm.sample(3e3, cores=2, tune=1000, nuts_kwargs={'max_treedepth': 15})"
705+
" trace2 = pm.sample(3e3, cores=2, tune=1000, max_treedepth=15)"
706706
]
707707
},
708708
{

docs/source/notebooks/hierarchical_partial_pooling.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@
171171
"source": [
172172
"with baseball_model:\n",
173173
" trace = pm.sample(2000, tune=1000, chains=2,\n",
174-
" nuts_kwargs={'target_accept': 0.95})"
174+
" target_accept=0.95)"
175175
]
176176
},
177177
{

docs/source/notebooks/stochastic_volatility.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@
171171
],
172172
"source": [
173173
"with model:\n",
174-
" trace = pm.sample(tune=2000, nuts_kwargs=dict(target_accept=.9))"
174+
" trace = pm.sample(tune=2000, target_accept=0.9)"
175175
]
176176
},
177177
{

docs/source/notebooks/weibull_aft.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@
184184
"with model_1:\n",
185185
" # Increase tune and change init to avoid divergences\n",
186186
" trace_1 = pm.sample(draws=1000, tune=1000,\n",
187-
" nuts_kwargs={'target_accept': 0.9},\n",
187+
" target_accept=0.9,\n",
188188
" init='adapt_diag')"
189189
]
190190
},
@@ -337,7 +337,7 @@
337337
"with model_2:\n",
338338
" # Increase tune and target_accept to avoid divergences\n",
339339
" trace_2 = pm.sample(draws=1000, tune=1000,\n",
340-
" nuts_kwargs={'target_accept': 0.9})"
340+
" target_accept=0.9)"
341341
]
342342
},
343343
{

pymc3/examples/arma_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def run(n_samples=1000):
7878
with model:
7979
trace = pm.sample(draws=n_samples,
8080
tune=1000,
81-
nuts_kwargs=dict(target_accept=.99))
81+
target_accept=.99)
8282

8383
pm.plots.traceplot(trace)
8484
pm.plots.forestplot(trace)

0 commit comments

Comments
 (0)