37
37
from pytensor .tensor .random .type import RandomType
38
38
39
39
from pymc import Model , modelcontext
40
- from pymc .backends .arviz import find_constants , find_observations
40
+ from pymc .backends .arviz import (
41
+ coords_and_dims_for_inferencedata ,
42
+ find_constants ,
43
+ find_observations ,
44
+ )
41
45
from pymc .distributions .multivariate import PosDefMatrix
42
46
from pymc .initial_point import StartDict
43
47
from pymc .logprob .utils import CheckParameterValue
@@ -392,17 +396,6 @@ def sample_blackjax_nuts(
392
396
393
397
vars_to_sample = list (get_default_varnames (var_names , include_transformed = keep_untransformed ))
394
398
395
- coords = {
396
- cname : np .array (cvals ) if isinstance (cvals , tuple ) else cvals
397
- for cname , cvals in model .coords .items ()
398
- if cvals is not None
399
- }
400
-
401
- dims = {
402
- var_name : [dim for dim in dims if dim is not None ]
403
- for var_name , dims in model .named_vars_to_dims .items ()
404
- }
405
-
406
399
(random_seed ,) = _get_seeds_per_chain (random_seed , 1 )
407
400
408
401
tic1 = datetime .now ()
@@ -485,7 +478,7 @@ def sample_blackjax_nuts(
485
478
"sampling_time" : (tic3 - tic2 ).total_seconds (),
486
479
}
487
480
488
- posterior = mcmc_samples
481
+ coords , dims = coords_and_dims_for_inferencedata ( model )
489
482
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
490
483
# and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
491
484
_update_coords_and_dims (coords = coords , dims = dims , idata_kwargs = idata_kwargs )
@@ -500,7 +493,7 @@ def sample_blackjax_nuts(
500
493
dims = dims ,
501
494
attrs = make_attrs (attrs , library = blackjax ),
502
495
)
503
- az_trace = to_trace (posterior = posterior , ** idata_kwargs )
496
+ az_trace = to_trace (posterior = mcmc_samples , ** idata_kwargs )
504
497
505
498
return az_trace
506
499
@@ -613,17 +606,6 @@ def sample_numpyro_nuts(
613
606
614
607
vars_to_sample = list (get_default_varnames (var_names , include_transformed = keep_untransformed ))
615
608
616
- coords = {
617
- cname : np .array (cvals ) if isinstance (cvals , tuple ) else cvals
618
- for cname , cvals in model .coords .items ()
619
- if cvals is not None
620
- }
621
-
622
- dims = {
623
- var_name : [dim for dim in dims if dim is not None ]
624
- for var_name , dims in model .named_vars_to_dims .items ()
625
- }
626
-
627
609
(random_seed ,) = _get_seeds_per_chain (random_seed , 1 )
628
610
629
611
tic1 = datetime .now ()
@@ -715,7 +697,7 @@ def sample_numpyro_nuts(
715
697
"sampling_time" : (tic3 - tic2 ).total_seconds (),
716
698
}
717
699
718
- posterior = mcmc_samples
700
+ coords , dims = coords_and_dims_for_inferencedata ( model )
719
701
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
720
702
# and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
721
703
_update_coords_and_dims (coords = coords , dims = dims , idata_kwargs = idata_kwargs )
@@ -730,5 +712,5 @@ def sample_numpyro_nuts(
730
712
dims = dims ,
731
713
attrs = make_attrs (attrs , library = numpyro ),
732
714
)
733
- az_trace = to_trace (posterior = posterior , ** idata_kwargs )
715
+ az_trace = to_trace (posterior = mcmc_samples , ** idata_kwargs )
734
716
return az_trace
0 commit comments