Skip to content

Commit 10c66b4

Browse files
Move Competence and BlockedStep to compound module
These types follow directly from the introduction of CompoundStep. Moving them decouples code that depends on them from the `arraystep` module.
1 parent 8547d6d commit 10c66b4

File tree

6 files changed

+114
-106
lines changed

6 files changed

+114
-106
lines changed

pymc/step_methods/arraystep.py

Lines changed: 4 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -12,116 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from abc import ABC, abstractmethod
16-
from enum import IntEnum, unique
17-
from typing import Callable, Dict, List, Tuple, Union, cast
15+
from abc import abstractmethod
16+
from typing import Callable, List, Tuple, Union, cast
1817

1918
import numpy as np
2019

2120
from numpy.random import uniform
22-
from pytensor.graph.basic import Variable
2321

2422
from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType
2523
from pymc.model import modelcontext
26-
from pymc.step_methods.compound import CompoundStep
24+
from pymc.step_methods.compound import BlockedStep
2725
from pymc.util import get_var_name
2826

29-
__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select", "Competence"]
30-
31-
32-
@unique
33-
class Competence(IntEnum):
34-
"""Enum for characterizing competence classes of step methods.
35-
Values include:
36-
0: INCOMPATIBLE
37-
1: COMPATIBLE
38-
2: PREFERRED
39-
3: IDEAL
40-
"""
41-
42-
INCOMPATIBLE = 0
43-
COMPATIBLE = 1
44-
PREFERRED = 2
45-
IDEAL = 3
46-
47-
48-
class BlockedStep(ABC):
49-
50-
stats_dtypes: List[Dict[str, type]] = []
51-
vars: List[Variable] = []
52-
53-
def __new__(cls, *args, **kwargs):
54-
blocked = kwargs.get("blocked")
55-
if blocked is None:
56-
# Try to look up default value from class
57-
blocked = getattr(cls, "default_blocked", True)
58-
kwargs["blocked"] = blocked
59-
60-
model = modelcontext(kwargs.get("model"))
61-
kwargs.update({"model": model})
62-
63-
# vars can either be first arg or a kwarg
64-
if "vars" not in kwargs and len(args) >= 1:
65-
vars = args[0]
66-
args = args[1:]
67-
elif "vars" in kwargs:
68-
vars = kwargs.pop("vars")
69-
else: # Assume all model variables
70-
vars = model.value_vars
71-
72-
if not isinstance(vars, (tuple, list)):
73-
vars = [vars]
74-
75-
if len(vars) == 0:
76-
raise ValueError("No free random variables to sample.")
77-
78-
if not blocked and len(vars) > 1:
79-
# In this case we create a separate sampler for each var
80-
# and append them to a CompoundStep
81-
steps = []
82-
for var in vars:
83-
step = super().__new__(cls)
84-
# If we don't return the instance we have to manually
85-
# call __init__
86-
step.__init__([var], *args, **kwargs)
87-
# Hack for creating the class correctly when unpickling.
88-
step.__newargs = ([var],) + args, kwargs
89-
steps.append(step)
90-
91-
return CompoundStep(steps)
92-
else:
93-
step = super().__new__(cls)
94-
# Hack for creating the class correctly when unpickling.
95-
step.__newargs = (vars,) + args, kwargs
96-
return step
97-
98-
# Hack for creating the class correctly when unpickling.
99-
def __getnewargs_ex__(self):
100-
return self.__newargs
101-
102-
@abstractmethod
103-
def step(self, point: PointType) -> Tuple[PointType, StatsType]:
104-
"""Perform a single step of the sampler."""
105-
106-
@staticmethod
107-
def competence(var, has_grad):
108-
return Competence.INCOMPATIBLE
109-
110-
@classmethod
111-
def _competence(cls, vars, have_grad):
112-
vars = np.atleast_1d(vars)
113-
have_grad = np.atleast_1d(have_grad)
114-
competences = []
115-
for var, has_grad in zip(vars, have_grad):
116-
try:
117-
competences.append(cls.competence(var, has_grad))
118-
except TypeError:
119-
competences.append(cls.competence(var))
120-
return competences
121-
122-
def stop_tuning(self):
123-
if hasattr(self, "tune"):
124-
self.tune = False
27+
__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select"]
12528

12629

12730
class ArrayStep(BlockedStep):

pymc/step_methods/compound.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,113 @@
1919
"""
2020

2121

22-
from typing import Tuple
22+
from abc import ABC, abstractmethod
23+
from enum import IntEnum, unique
24+
from typing import Dict, List, Tuple
25+
26+
import numpy as np
27+
28+
from pytensor.graph.basic import Variable
2329

2430
from pymc.blocking import PointType, StatsType
31+
from pymc.model import modelcontext
32+
33+
__all__ = ("Competence", "CompoundStep")
34+
35+
36+
@unique
37+
class Competence(IntEnum):
38+
"""Enum for characterizing competence classes of step methods.
39+
Values include:
40+
0: INCOMPATIBLE
41+
1: COMPATIBLE
42+
2: PREFERRED
43+
3: IDEAL
44+
"""
45+
46+
INCOMPATIBLE = 0
47+
COMPATIBLE = 1
48+
PREFERRED = 2
49+
IDEAL = 3
50+
51+
52+
class BlockedStep(ABC):
53+
54+
stats_dtypes: List[Dict[str, type]] = []
55+
vars: List[Variable] = []
56+
57+
def __new__(cls, *args, **kwargs):
58+
blocked = kwargs.get("blocked")
59+
if blocked is None:
60+
# Try to look up default value from class
61+
blocked = getattr(cls, "default_blocked", True)
62+
kwargs["blocked"] = blocked
63+
64+
model = modelcontext(kwargs.get("model"))
65+
kwargs.update({"model": model})
66+
67+
# vars can either be first arg or a kwarg
68+
if "vars" not in kwargs and len(args) >= 1:
69+
vars = args[0]
70+
args = args[1:]
71+
elif "vars" in kwargs:
72+
vars = kwargs.pop("vars")
73+
else: # Assume all model variables
74+
vars = model.value_vars
75+
76+
if not isinstance(vars, (tuple, list)):
77+
vars = [vars]
78+
79+
if len(vars) == 0:
80+
raise ValueError("No free random variables to sample.")
81+
82+
if not blocked and len(vars) > 1:
83+
# In this case we create a separate sampler for each var
84+
# and append them to a CompoundStep
85+
steps = []
86+
for var in vars:
87+
step = super().__new__(cls)
88+
# If we don't return the instance we have to manually
89+
# call __init__
90+
step.__init__([var], *args, **kwargs)
91+
# Hack for creating the class correctly when unpickling.
92+
step.__newargs = ([var],) + args, kwargs
93+
steps.append(step)
94+
95+
return CompoundStep(steps)
96+
else:
97+
step = super().__new__(cls)
98+
# Hack for creating the class correctly when unpickling.
99+
step.__newargs = (vars,) + args, kwargs
100+
return step
101+
102+
# Hack for creating the class correctly when unpickling.
103+
def __getnewargs_ex__(self):
104+
return self.__newargs
105+
106+
@abstractmethod
107+
def step(self, point: PointType) -> Tuple[PointType, StatsType]:
108+
"""Perform a single step of the sampler."""
109+
110+
@staticmethod
111+
def competence(var, has_grad):
112+
return Competence.INCOMPATIBLE
113+
114+
@classmethod
115+
def _competence(cls, vars, have_grad):
116+
vars = np.atleast_1d(vars)
117+
have_grad = np.atleast_1d(have_grad)
118+
competences = []
119+
for var, has_grad in zip(vars, have_grad):
120+
try:
121+
competences.append(cls.competence(var, has_grad))
122+
except TypeError:
123+
competences.append(cls.competence(var))
124+
return competences
125+
126+
def stop_tuning(self):
127+
if hasattr(self, "tune"):
128+
self.tune = False
25129

26130

27131
class CompoundStep:

pymc/step_methods/hmc/hmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import numpy as np
2020

2121
from pymc.stats.convergence import SamplerWarning
22-
from pymc.step_methods.arraystep import Competence
22+
from pymc.step_methods.compound import Competence
2323
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
2424
from pymc.step_methods.hmc.integration import IntegrationError, State
2525
from pymc.vartypes import discrete_types

pymc/step_methods/hmc/nuts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pymc.math import logbern
2222
from pymc.pytensorf import floatX
2323
from pymc.stats.convergence import SamplerWarning
24-
from pymc.step_methods.arraystep import Competence
24+
from pymc.step_methods.compound import Competence
2525
from pymc.step_methods.hmc import integration
2626
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
2727
from pymc.step_methods.hmc.integration import IntegrationError, State

pymc/step_methods/metropolis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@
3636
from pymc.step_methods.arraystep import (
3737
ArrayStep,
3838
ArrayStepShared,
39-
Competence,
4039
PopulationArrayStepShared,
4140
StatsType,
4241
metrop_select,
4342
)
43+
from pymc.step_methods.compound import Competence
4444

4545
__all__ = [
4646
"Metropolis",

pymc/step_methods/slicer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
from pymc.blocking import RaveledVars, StatsType
2323
from pymc.model import modelcontext
24-
from pymc.step_methods.arraystep import ArrayStep, Competence
24+
from pymc.step_methods.arraystep import ArrayStep
25+
from pymc.step_methods.compound import Competence
2526
from pymc.util import get_value_vars_from_user_vars
2627
from pymc.vartypes import continuous_types
2728

0 commit comments

Comments
 (0)