Skip to content

Commit 4c020e7

Browse files
Move population size warnings to population sampling submodule
1 parent 6ad9693 commit 4c020e7

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

pymc/sampling/mcmc.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from pymc.sampling.parallel import Draw, _cpu_count
4242
from pymc.sampling.population import _sample_population
4343
from pymc.stats.convergence import log_warning_stats, run_convergence_checks
44-
from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
44+
from pymc.step_methods import NUTS, CompoundStep
4545
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
4646
from pymc.step_methods.hmc import quadpotential
4747
from pymc.util import (
@@ -538,32 +538,11 @@ def sample(
538538
parallel = False
539539
if not parallel:
540540
if has_population_samplers:
541-
has_demcmc = np.any(
542-
[
543-
isinstance(m, DEMetropolis)
544-
for m in (step.methods if isinstance(step, CompoundStep) else [step])
545-
]
546-
)
547541
_log.info(f"Population sampling ({chains} chains)")
548-
549-
initial_point_model_size = sum(initial_points[0][n.name].size for n in model.value_vars)
550-
551-
if has_demcmc and chains < 3:
552-
raise ValueError(
553-
"DEMetropolis requires at least 3 chains. "
554-
"For this {}-dimensional model you should use ≥{} chains".format(
555-
initial_point_model_size, initial_point_model_size + 1
556-
)
557-
)
558-
if has_demcmc and chains <= initial_point_model_size:
559-
warnings.warn(
560-
"DEMetropolis should be used with more chains than dimensions! "
561-
"(The model has {} dimensions.)".format(initial_point_model_size),
562-
UserWarning,
563-
stacklevel=2,
564-
)
565542
_print_step_hierarchy(step)
566-
mtrace = _sample_population(parallelize=cores > 1, **sample_args)
543+
mtrace = _sample_population(
544+
initial_points=initial_points, parallelize=cores > 1, **sample_args
545+
)
567546
else:
568547
_log.info(f"Sequential sampling ({chains} chains in 1 job)")
569548
_print_step_hierarchy(step)

pymc/sampling/population.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Specializes on running MCMCs with population step methods."""
1616

1717
import logging
18+
import warnings
1819

1920
from copy import copy
2021
from typing import Iterator, List, Sequence, Tuple, Union
@@ -36,6 +37,7 @@
3637
PopulationArrayStepShared,
3738
StatsType,
3839
)
40+
from pymc.step_methods.metropolis import DEMetropolis
3941
from pymc.util import RandomSeed
4042

4143
__all__ = ()
@@ -49,6 +51,7 @@
4951

5052
def _sample_population(
5153
*,
54+
initial_points,
5255
draws: int,
5356
chains: int,
5457
start: Sequence[PointType],
@@ -86,6 +89,13 @@ def _sample_population(
8689
trace : MultiTrace
8790
Contains samples of all chains
8891
"""
92+
warn_population_size(
93+
step=step,
94+
initial_points=initial_points,
95+
model=model,
96+
chains=chains,
97+
)
98+
8999
sampling = _prepare_iter_population(
90100
draws=draws,
91101
step=step,
@@ -106,6 +116,33 @@ def _sample_population(
106116
return MultiTrace(latest_traces)
107117

108118

119+
def warn_population_size(*, step: CompoundStep, initial_points, model, chains: int):
120+
has_demcmc = np.any(
121+
[
122+
isinstance(m, DEMetropolis)
123+
for m in (step.methods if isinstance(step, CompoundStep) else [step])
124+
]
125+
)
126+
127+
initial_point_model_size = sum(initial_points[0][n.name].size for n in model.value_vars)
128+
129+
if has_demcmc and chains < 3:
130+
raise ValueError(
131+
"DEMetropolis requires at least 3 chains. "
132+
"For this {}-dimensional model you should use ≥{} chains".format(
133+
initial_point_model_size, initial_point_model_size + 1
134+
)
135+
)
136+
if has_demcmc and chains <= initial_point_model_size:
137+
warnings.warn(
138+
"DEMetropolis should be used with more chains than dimensions! "
139+
"(The model has {} dimensions.)".format(initial_point_model_size),
140+
UserWarning,
141+
stacklevel=2,
142+
)
143+
return
144+
145+
109146
class PopulationStepper:
110147
"""Wraps population of step methods to step them in parallel with single or multiprocessing."""
111148

0 commit comments

Comments
 (0)