|
41 | 41 | from pymc.sampling.parallel import Draw, _cpu_count
|
42 | 42 | from pymc.sampling.population import _sample_population
|
43 | 43 | from pymc.stats.convergence import log_warning_stats, run_convergence_checks
|
44 |
| -from pymc.step_methods import NUTS, CompoundStep, DEMetropolis |
| 44 | +from pymc.step_methods import NUTS, CompoundStep |
45 | 45 | from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
|
46 | 46 | from pymc.step_methods.hmc import quadpotential
|
47 | 47 | from pymc.util import (
|
@@ -538,32 +538,11 @@ def sample(
|
538 | 538 | parallel = False
|
539 | 539 | if not parallel:
|
540 | 540 | if has_population_samplers:
|
541 |
| - has_demcmc = np.any( |
542 |
| - [ |
543 |
| - isinstance(m, DEMetropolis) |
544 |
| - for m in (step.methods if isinstance(step, CompoundStep) else [step]) |
545 |
| - ] |
546 |
| - ) |
547 | 541 | _log.info(f"Population sampling ({chains} chains)")
|
548 |
| - |
549 |
| - initial_point_model_size = sum(initial_points[0][n.name].size for n in model.value_vars) |
550 |
| - |
551 |
| - if has_demcmc and chains < 3: |
552 |
| - raise ValueError( |
553 |
| - "DEMetropolis requires at least 3 chains. " |
554 |
| - "For this {}-dimensional model you should use ≥{} chains".format( |
555 |
| - initial_point_model_size, initial_point_model_size + 1 |
556 |
| - ) |
557 |
| - ) |
558 |
| - if has_demcmc and chains <= initial_point_model_size: |
559 |
| - warnings.warn( |
560 |
| - "DEMetropolis should be used with more chains than dimensions! " |
561 |
| - "(The model has {} dimensions.)".format(initial_point_model_size), |
562 |
| - UserWarning, |
563 |
| - stacklevel=2, |
564 |
| - ) |
565 | 542 | _print_step_hierarchy(step)
|
566 |
| - mtrace = _sample_population(parallelize=cores > 1, **sample_args) |
| 543 | + mtrace = _sample_population( |
| 544 | + initial_points=initial_points, parallelize=cores > 1, **sample_args |
| 545 | + ) |
567 | 546 | else:
|
568 | 547 | _log.info(f"Sequential sampling ({chains} chains in 1 job)")
|
569 | 548 | _print_step_hierarchy(step)
|
|
0 commit comments