@@ -221,6 +221,94 @@ def all_continuous(vars):
221
221
return True
222
222
223
223
224
+ def _sample_external_nuts (
225
+ sampler : str ,
226
+ draws : int ,
227
+ tune : int ,
228
+ chains : int ,
229
+ target_accept : float ,
230
+ random_seed : Union [RandomState , None ],
231
+ initvals : Union [StartDict , Sequence [Optional [StartDict ]], None ],
232
+ model : Model ,
233
+ progressbar : bool ,
234
+ idata_kwargs : Optional [Dict ],
235
+ ** kwargs ,
236
+ ):
237
+ warnings .warn ("Use of external NUTS sampler is still experimental" , UserWarning )
238
+
239
+ if sampler == "nutpie" :
240
+ try :
241
+ import nutpie
242
+ except ImportError as err :
243
+ raise ImportError (
244
+ "nutpie not found. Install it with conda install -c conda-forge nutpie"
245
+ ) from err
246
+
247
+ if initvals is not None :
248
+ warnings .warn (
249
+ "`initvals` are currently not passed to nutpie sampler. "
250
+ "Use `init_mean` kwarg following nutpie specification instead." ,
251
+ UserWarning ,
252
+ )
253
+
254
+ if idata_kwargs is not None :
255
+ warnings .warn (
256
+ "`idata_kwargs` are currently ignored by the nutpie sampler" ,
257
+ UserWarning ,
258
+ )
259
+
260
+ compiled_model = nutpie .compile_pymc_model (model )
261
+ idata = nutpie .sample (
262
+ compiled_model ,
263
+ draws = draws ,
264
+ tune = tune ,
265
+ chains = chains ,
266
+ target_accept = target_accept ,
267
+ seed = _get_seeds_per_chain (random_seed , 1 )[0 ],
268
+ progress_bar = progressbar ,
269
+ ** kwargs ,
270
+ )
271
+ return idata
272
+
273
+ elif sampler == "numpyro" :
274
+ import pymc .sampling .jax as pymc_jax
275
+
276
+ idata = pymc_jax .sample_numpyro_nuts (
277
+ draws = draws ,
278
+ tune = tune ,
279
+ chains = chains ,
280
+ target_accept = target_accept ,
281
+ random_seed = random_seed ,
282
+ initvals = initvals ,
283
+ model = model ,
284
+ progressbar = progressbar ,
285
+ idata_kwargs = idata_kwargs ,
286
+ ** kwargs ,
287
+ )
288
+ return idata
289
+
290
+ elif sampler == "blackjax" :
291
+ import pymc .sampling .jax as pymc_jax
292
+
293
+ idata = pymc_jax .sample_blackjax_nuts (
294
+ draws = draws ,
295
+ tune = tune ,
296
+ chains = chains ,
297
+ target_accept = target_accept ,
298
+ random_seed = random_seed ,
299
+ initvals = initvals ,
300
+ model = model ,
301
+ idata_kwargs = idata_kwargs ,
302
+ ** kwargs ,
303
+ )
304
+ return idata
305
+
306
+ else :
307
+ raise ValueError (
308
+ f"Sampler { sampler } not found. Choose one of ['nutpie', 'numpyro', 'blackjax', 'pymc']."
309
+ )
310
+
311
+
224
312
def sample (
225
313
draws : int = 1000 ,
226
314
step = None ,
@@ -239,6 +327,7 @@ def sample(
239
327
callback = None ,
240
328
jitter_max_retries : int = 10 ,
241
329
* ,
330
+ nuts_sampler : str = "pymc" ,
242
331
return_inferencedata : bool = True ,
243
332
keep_warning_stat : bool = False ,
244
333
idata_kwargs : dict = None ,
@@ -257,6 +346,7 @@ def sample(
257
346
init : str
258
347
Initialization method to use for auto-assigned NUTS samplers. See `pm.init_nuts` for a list
259
348
of all options. This argument is ignored when manually passing the NUTS step method.
349
+ Only applicable to the pymc nuts sampler.
260
350
step : function or iterable of functions
261
351
A step function or collection of functions. If there are variables without step methods,
262
352
step methods for those variables will be assigned automatically. By default the NUTS step
@@ -306,6 +396,10 @@ def sample(
306
396
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform
307
397
jitter that yields a finite probability. This applies to ``jitter+adapt_diag`` and
308
398
``jitter+adapt_full`` init methods.
399
+ nuts_sampler : str
400
+ Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
401
+ This requires the chosen sampler to be installed.
402
+ All samplers, except "pymc", require the full model to be continuous.
309
403
return_inferencedata : bool
310
404
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a
311
405
`MultiTrace` (False). Defaults to `True`.
@@ -401,7 +495,7 @@ def sample(
401
495
if "nuts" in kwargs :
402
496
kwargs ["nuts" ]["target_accept" ] = kwargs .pop ("target_accept" )
403
497
else :
404
- kwargs = { "nuts" : {"target_accept" : kwargs .pop ("target_accept" )} }
498
+ kwargs [ "nuts" ] = {"target_accept" : kwargs .pop ("target_accept" )}
405
499
if isinstance (trace , list ):
406
500
raise DeprecationWarning (
407
501
"We have removed support for partial traces because it simplified things."
@@ -441,8 +535,6 @@ def sample(
441
535
msg = "Only %s samples in chain." % draws
442
536
_log .warning (msg )
443
537
444
- draws += tune
445
-
446
538
auto_nuts_init = True
447
539
if step is not None :
448
540
if isinstance (step , CompoundStep ):
@@ -455,6 +547,25 @@ def sample(
455
547
initial_points = None
456
548
step = assign_step_methods (model , step , methods = pm .STEP_METHODS , step_kwargs = kwargs )
457
549
550
+ if nuts_sampler != "pymc" :
551
+ if not isinstance (step , NUTS ):
552
+ raise ValueError (
553
+ "Model can not be sampled with NUTS alone. Your model is probably not continuous."
554
+ )
555
+ return _sample_external_nuts (
556
+ sampler = nuts_sampler ,
557
+ draws = draws ,
558
+ tune = tune ,
559
+ chains = chains ,
560
+ target_accept = kwargs .pop ("nuts" , {}).get ("target_accept" , 0.8 ),
561
+ random_seed = random_seed ,
562
+ initvals = initvals ,
563
+ model = model ,
564
+ progressbar = progressbar ,
565
+ idata_kwargs = idata_kwargs ,
566
+ ** kwargs ,
567
+ )
568
+
458
569
if isinstance (step , list ):
459
570
step = CompoundStep (step )
460
571
elif isinstance (step , NUTS ) and auto_nuts_init :
@@ -503,7 +614,7 @@ def sample(
503
614
)
504
615
505
616
sample_args = {
506
- "draws" : draws ,
617
+ "draws" : draws + tune , # FIXME: Why is tune added to draws?
507
618
"step" : step ,
508
619
"start" : initial_points ,
509
620
"traces" : traces ,
0 commit comments