Skip to content

Commit 6ad9693

Browse files
Require explicit kwargs in internal sampling functions
1 parent 434333f commit 6ad9693

File tree

2 files changed

+39
-11
lines changed

2 files changed

+39
-11
lines changed

pymc/sampling/mcmc.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import warnings
2222

2323
from collections import defaultdict
24-
from copy import copy
2524
from typing import Iterator, List, Optional, Sequence, Tuple, Union
2625

2726
import numpy as np
@@ -654,6 +653,7 @@ def _check_start_shape(model, start: PointType):
654653

655654

656655
def _sample_many(
656+
*,
657657
draws: int,
658658
chains: int,
659659
start: Sequence[PointType],
@@ -754,10 +754,16 @@ def _sample(
754754
"""
755755
skip_first = kwargs.get("skip_first", 0)
756756

757-
trace = copy(trace)
758-
759757
sampling_gen = _iter_sample(
760-
draws, step, start, trace, chain, tune, model, random_seed, callback
758+
draws=draws,
759+
step=step,
760+
start=start,
761+
trace=trace,
762+
chain=chain,
763+
tune=tune,
764+
model=model,
765+
random_seed=random_seed,
766+
callback=callback,
761767
)
762768
_pbar_data = {"chain": chain, "divergences": 0}
763769
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
@@ -832,12 +838,23 @@ def iter_sample(
832838
for trace in iter_sample(500, step):
833839
...
834840
"""
835-
sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed, callback)
841+
sampling = _iter_sample(
842+
draws=draws,
843+
step=step,
844+
start=start,
845+
trace=trace,
846+
chain=chain,
847+
tune=tune,
848+
model=model,
849+
random_seed=random_seed,
850+
callback=callback,
851+
)
836852
for i, (strace, _) in enumerate(sampling):
837853
yield MultiTrace([strace[: i + 1]])
838854

839855

840856
def _iter_sample(
857+
*,
841858
draws: int,
842859
step,
843860
start: PointType,
@@ -934,6 +951,7 @@ def _iter_sample(
934951

935952

936953
def _mp_sample(
954+
*,
937955
draws: int,
938956
tune: int,
939957
step,

pymc/sampling/population.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949

5050
def _sample_population(
51+
*,
5152
draws: int,
5253
chains: int,
5354
start: Sequence[PointType],
@@ -86,10 +87,10 @@ def _sample_population(
8687
Contains samples of all chains
8788
"""
8889
sampling = _prepare_iter_population(
89-
draws,
90-
step,
91-
start,
92-
parallelize,
90+
draws=draws,
91+
step=step,
92+
start=start,
93+
parallelize=parallelize,
9394
tune=tune,
9495
model=model,
9596
random_seed=random_seed,
@@ -259,6 +260,7 @@ def step(self, tune_stop: bool, population) -> List[Tuple[PointType, StatsType]]
259260

260261

261262
def _prepare_iter_population(
263+
*,
262264
draws: int,
263265
step,
264266
start: Sequence[PointType],
@@ -344,11 +346,19 @@ def _prepare_iter_population(
344346

345347
# Because the preparations above are expensive, the actual iterator is
346348
# in another method. This way the progbar will not be disturbed.
347-
return _iter_population(draws, tune, popstep, steppers, traces, population)
349+
return _iter_population(
350+
draws=draws, tune=tune, popstep=popstep, steppers=steppers, traces=traces, points=population
351+
)
348352

349353

350354
def _iter_population(
351-
draws: int, tune: int, popstep: PopulationStepper, steppers, traces: Sequence[BaseTrace], points
355+
*,
356+
draws: int,
357+
tune: int,
358+
popstep: PopulationStepper,
359+
steppers,
360+
traces: Sequence[BaseTrace],
361+
points,
352362
) -> Iterator[Sequence[BaseTrace]]:
353363
"""Iterate a ``PopulationStepper``.
354364

0 commit comments

Comments
 (0)