Skip to content

Commit 0ab1596

Browse files
committed
Resume sampling from existing ZarrTrace
1 parent e81cd06 commit 0ab1596

File tree

5 files changed

+347
-29
lines changed

5 files changed

+347
-29
lines changed

pymc/backends/__init__.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
7373
from pymc.backends.base import BaseTrace, IBaseTrace
7474
from pymc.backends.ndarray import NDArray
75-
from pymc.backends.zarr import ZarrTrace
75+
from pymc.backends.zarr import TraceAlreadyInitialized, ZarrTrace
7676
from pymc.blocking import PointType
7777
from pymc.model import Model
7878
from pymc.step_methods.compound import BlockedStep, CompoundStep
@@ -132,15 +132,41 @@ def init_traces(
132132
) -> tuple[RunType | None, Sequence[IBaseTrace]]:
133133
"""Initialize a trace recorder for each chain."""
134134
if isinstance(backend, ZarrTrace):
135-
backend.init_trace(
136-
chains=chains,
137-
draws=expected_length - tune,
138-
tune=tune,
139-
step=step,
140-
model=model,
141-
vars=trace_vars,
142-
test_point=initial_point,
143-
)
135+
try:
136+
backend.init_trace(
137+
chains=chains,
138+
draws=expected_length - tune,
139+
tune=tune,
140+
step=step,
141+
model=model,
142+
vars=trace_vars,
143+
test_point=initial_point,
144+
)
145+
except TraceAlreadyInitialized:
146+
# Trace has already been initialized. We need to make sure that the
147+
# tracked variable names and the number of chains match, and then resize
148+
# the zarr groups to the desired number of draws and tune.
149+
backend.assert_model_and_step_are_compatible(
150+
step=step,
151+
model=model,
152+
vars=trace_vars,
153+
)
154+
assert backend.posterior.chain.size == chains, (
155+
f"The requested number of chains {chains} does not match the number "
156+
f"of chains stored in the trace ({backend.posterior.chain.size})."
157+
)
158+
vars, var_names = backend.parse_varnames(model=model, vars=trace_vars)
159+
backend.link_model_and_step(
160+
chains=chains,
161+
draws=expected_length - tune,
162+
tune=tune,
163+
step=step,
164+
model=model,
165+
vars=vars,
166+
var_names=var_names,
167+
test_point=initial_point,
168+
)
169+
backend.resize(tune=tune, draws=expected_length - tune)
144170
return None, backend.straces
145171
if HAS_MCB and isinstance(backend, Backend):
146172
return init_chain_adapters(

pymc/sampling/mcmc.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -993,11 +993,8 @@ def _sample_return(
993993
Final step of `pm.sample`.
994994
"""
995995
if isinstance(traces, ZarrTrace):
996-
# Split warmup from posterior samples
997-
traces.split_warmup_groups()
998-
999996
# Set sampling time
1000-
traces.sampling_time = t_sampling
997+
traces.sampling_time = traces.sampling_time + t_sampling
1001998

1002999
# Compute number of actual draws per chain
10031000
total_draws_per_chain = traces._sampling_state.draw_idx[:]
@@ -1226,7 +1223,7 @@ def _sample(
12261223
callback=callback,
12271224
)
12281225
try:
1229-
for it, stats in enumerate(sampling_gen):
1226+
for it, stats in sampling_gen:
12301227
progress_manager.update(
12311228
chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune
12321229
)
@@ -1251,7 +1248,7 @@ def _iter_sample(
12511248
rng: np.random.Generator,
12521249
model: Model | None = None,
12531250
callback: SamplingIteratorCallback | None = None,
1254-
) -> Iterator[list[dict[str, Any]]]:
1251+
) -> Iterator[tuple[int, list[dict[str, Any]]]]:
12551252
"""Sample one chain with a generator (singleprocess).
12561253
12571254
Parameters
@@ -1285,14 +1282,33 @@ def _iter_sample(
12851282
step.set_rng(rng)
12861283

12871284
point = start
1285+
initial_draw_idx = 0
1286+
step.tune = bool(tune)
1287+
if hasattr(step, "reset_tuning"):
1288+
step.reset_tuning()
12881289
if isinstance(trace, ZarrChain):
12891290
trace.link_stepper(step)
1291+
stored_draw_idx = trace._sampling_state.draw_idx[chain]
1292+
stored_sampling_state = trace._sampling_state.sampling_state[chain]
1293+
if stored_draw_idx > 0:
1294+
if stored_sampling_state is not None:
1295+
step.sampling_state = stored_sampling_state
1296+
else:
1297+
raise RuntimeError(
1298+
"Cannot use the supplied ZarrTrace to restart sampling because "
1299+
"it has no sampling_state information stored. You will have to "
1300+
"resample from scratch."
1301+
)
1302+
initial_draw_idx = stored_draw_idx
1303+
point = trace.get_mcmc_point()
1304+
else:
1305+
# Store initial point in trace
1306+
trace.set_mcmc_point(point)
12901307

12911308
try:
1292-
step.tune = bool(tune)
1293-
if hasattr(step, "reset_tuning"):
1294-
step.reset_tuning()
1295-
for i in range(draws):
1309+
for i in range(initial_draw_idx, draws):
1310+
diverging = False
1311+
12961312
if i == 0 and hasattr(step, "iter_count"):
12971313
step.iter_count = 0
12981314
if i == tune:
@@ -1308,7 +1324,7 @@ def _iter_sample(
13081324
draw=Draw(chain, i == draws, i, i < tune, stats, point),
13091325
)
13101326

1311-
yield stats
1327+
yield i, stats
13121328

13131329
except (KeyboardInterrupt, BaseException):
13141330
if isinstance(trace, ZarrChain):

pymc/sampling/parallel.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,24 @@ def _start_loop(self):
194194

195195
draw = 0
196196
tuning = True
197+
if self._zarr_recording:
198+
trace = self._zarr_chain
199+
stored_draw_idx = trace._sampling_state.draw_idx[self.chain]
200+
stored_sampling_state = trace._sampling_state.sampling_state[self.chain]
201+
if stored_draw_idx > 0:
202+
if stored_sampling_state is not None:
203+
self._step_method.sampling_state = stored_sampling_state
204+
else:
205+
raise RuntimeError(
206+
"Cannot use the supplied ZarrTrace to restart sampling because "
207+
"it has no sampling_state information stored. You will have to "
208+
"resample from scratch."
209+
)
210+
draw = stored_draw_idx
211+
self._write_point(trace.get_mcmc_point())
212+
else:
213+
# Store starting point in trace's mcmc_point
214+
trace.set_mcmc_point(self._point)
197215

198216
msg = self._recv_msg()
199217
if msg[0] == "abort":

pymc/sampling/population.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from collections.abc import Iterator, Sequence
2121
from copy import copy
22-
from typing import TypeAlias
22+
from typing import TypeAlias, cast
2323

2424
import cloudpickle
2525
import numpy as np
@@ -37,7 +37,7 @@
3737
PopulationArrayStepShared,
3838
StatsType,
3939
)
40-
from pymc.step_methods.compound import StepMethodState
40+
from pymc.step_methods.compound import CompoundStepState, StepMethodState
4141
from pymc.step_methods.metropolis import DEMetropolis
4242
from pymc.util import CustomProgress
4343

@@ -54,7 +54,7 @@ def _sample_population(
5454
*,
5555
initial_points: Sequence[PointType],
5656
draws: int,
57-
start: Sequence[PointType],
57+
start: list[PointType],
5858
rngs: Sequence[np.random.Generator],
5959
step: BlockedStep | CompoundStep,
6060
tune: int,
@@ -151,7 +151,9 @@ def warn_population_size(
151151
class PopulationStepper:
152152
"""Wraps population of step methods to step them in parallel with single or multiprocessing."""
153153

154-
def __init__(self, steppers, parallelize: bool, progressbar: bool = True):
154+
def __init__(
155+
self, steppers, parallelize: bool, progressbar: bool = True, first_draw_idx: int = 0
156+
):
155157
"""Use multiprocessing to parallelize chains.
156158
157159
Falls back to sequential evaluation if multiprocessing fails.
@@ -330,7 +332,7 @@ def _prepare_iter_population(
330332
*,
331333
draws: int,
332334
step,
333-
start: Sequence[PointType],
335+
start: list[PointType],
334336
parallelize: bool,
335337
traces: Sequence[BaseTrace],
336338
tune: int,
@@ -376,10 +378,35 @@ def _prepare_iter_population(
376378
raise ValueError("Argument `draws` should be above 0.")
377379

378380
# The initialization of traces, samplers and points must happen in the right order:
381+
# 0. previous sampling state is loaded if possible
379382
# 1. population of points is created
380383
# 2. steppers are initialized and linked to the points object
381384
# 3. a PopulationStepper is configured for parallelized stepping
382385

386+
# 0. load sampling state and start point from traces if possible
387+
first_draw_idx = 0
388+
stored_sampling_states: Sequence[StepMethodState | CompoundStepState] | None = None
389+
can_resume_sampling = False
390+
if isinstance(traces[0], ZarrChain):
391+
# All traces share the same store. This lets us load the past sampling states and draw
392+
# indices for all chains
393+
stored_draw_idxs = traces[0]._sampling_state.draw_idx[:]
394+
stored_sampling_states = cast(
395+
Sequence[StepMethodState | CompoundStepState],
396+
traces[0]._sampling_state.sampling_state[:],
397+
)
398+
can_resume_sampling = (
399+
all(stored_draw_idxs > 0)
400+
and all(stored_draw_idxs == stored_draw_idxs[0])
401+
and all(sampling_state is not None for sampling_state in stored_sampling_states)
402+
)
403+
for chain, trace in enumerate(traces):
404+
trace = cast(ZarrChain, trace)
405+
if can_resume_sampling:
406+
start[chain] = trace.get_mcmc_point()
407+
else:
408+
trace.set_mcmc_point(start[chain])
409+
383410
# 1. create a population (points) that tracks each chain
384411
# it is updated as the chains are advanced
385412
population = [start[c] for c in range(nchains)]
@@ -401,15 +428,25 @@ def _prepare_iter_population(
401428
for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]:
402429
if isinstance(sm, PopulationArrayStepShared):
403430
sm.link_population(population, c)
431+
if can_resume_sampling:
432+
chainstep.sampling_state = cast(Sequence[CompoundStepState], stored_sampling_states)[c]
404433
steppers.append(chainstep)
405434

406435
# 3. configure the PopulationStepper (expensive call)
407-
popstep = PopulationStepper(steppers, parallelize, progressbar=progressbar)
436+
popstep = PopulationStepper(
437+
steppers, parallelize, progressbar=progressbar, first_draw_idx=first_draw_idx
438+
)
408439

409440
# Because the preparations above are expensive, the actual iterator is
410441
# in another method. This way the progbar will not be disturbed.
411442
return _iter_population(
412-
draws=draws, tune=tune, popstep=popstep, steppers=steppers, traces=traces, points=population
443+
draws=draws,
444+
tune=tune,
445+
popstep=popstep,
446+
steppers=steppers,
447+
traces=traces,
448+
points=population,
449+
first_draw_idx=first_draw_idx,
413450
)
414451

415452

@@ -421,6 +458,7 @@ def _iter_population(
421458
steppers,
422459
traces: Sequence[BaseTrace],
423460
points,
461+
first_draw_idx=0,
424462
) -> Iterator[int]:
425463
"""Iterate a ``PopulationStepper``.
426464
@@ -450,7 +488,7 @@ def _iter_population(
450488
try:
451489
with popstep:
452490
# iterate draws of all chains
453-
for i in range(draws):
491+
for i in range(first_draw_idx, draws):
454492
# this call steps all chains and returns a list of (point, stats)
455493
# the `popstep` may interact with subprocesses internally
456494
updates = popstep.step(i == tune, points)

0 commit comments

Comments
 (0)