Skip to content

Commit 26f7636

Browse files
Create traces with init_traces helper function
1 parent c5aff17 commit 26f7636

File tree

2 files changed

+39
-13
lines changed

2 files changed

+39
-13
lines changed

pymc/backends/__init__.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,15 @@
6161
6262
"""
6363
from copy import copy
64-
from typing import Dict, List, Optional
64+
from typing import Dict, List, Optional, Sequence, Union
65+
66+
import numpy as np
6567

6668
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
67-
from pymc.backends.base import BaseTrace
69+
from pymc.backends.base import BaseTrace, IBaseTrace
6870
from pymc.backends.ndarray import NDArray
6971
from pymc.model import Model
72+
from pymc.step_methods.compound import BlockedStep, CompoundStep
7073

7174
__all__ = ["to_inference_data", "predictions_to_inference_data"]
7275

@@ -92,3 +95,26 @@ def _init_trace(
9295

9396
strace.setup(expected_length, chain_number, stats_dtypes)
9497
return strace
98+
99+
100+
def init_traces(
101+
*,
102+
backend: Optional[BaseTrace],
103+
chains: int,
104+
expected_length: int,
105+
step: Union[BlockedStep, CompoundStep],
106+
var_dtypes: Dict[str, np.dtype],
107+
var_shapes: Dict[str, Sequence[int]],
108+
model: Model,
109+
) -> Sequence[IBaseTrace]:
110+
"""Initializes a trace recorder for each chain."""
111+
return [
112+
_init_trace(
113+
expected_length=expected_length,
114+
stats_dtypes=step.stats_dtypes,
115+
chain_number=chain_number,
116+
trace=backend,
117+
model=model,
118+
)
119+
for chain_number in range(chains)
120+
]

pymc/sampling/mcmc.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
import pymc as pm
3434

35-
from pymc.backends import _init_trace
35+
from pymc.backends import init_traces
3636
from pymc.backends.base import BaseTrace, IBaseTrace, MultiTrace, _choose_chains
3737
from pymc.blocking import DictToArrayBijection
3838
from pymc.exceptions import SamplingError
@@ -486,21 +486,21 @@ def sample(
486486
initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)]
487487

488488
# One final check that shapes and logps at the starting points are okay.
489+
ip: Dict[str, np.ndarray]
489490
for ip in initial_points:
490491
model.check_start_vals(ip)
491492
_check_start_shape(model, ip)
492493

493494
# Create trace backends for each chain
494-
traces = [
495-
_init_trace(
496-
expected_length=draws + tune,
497-
stats_dtypes=step.stats_dtypes,
498-
chain_number=chain_number,
499-
trace=trace,
500-
model=model,
501-
)
502-
for chain_number in range(chains)
503-
]
495+
traces = init_traces(
496+
backend=trace,
497+
chains=chains,
498+
expected_length=draws + tune,
499+
step=step,
500+
var_dtypes={vn: v.dtype for vn, v in ip.items()},
501+
var_shapes={vn: v.shape for vn, v in ip.items()},
502+
model=model,
503+
)
504504

505505
sample_args = {
506506
"draws": draws,

0 commit comments

Comments
 (0)