|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -from abc import ABC, abstractmethod |
16 |
| -from enum import IntEnum, unique |
17 |
| -from typing import Callable, Dict, List, Tuple, Union, cast |
| 15 | +from abc import abstractmethod |
| 16 | +from typing import Callable, List, Tuple, Union, cast |
18 | 17 |
|
19 | 18 | import numpy as np
|
20 | 19 |
|
21 | 20 | from numpy.random import uniform
|
22 |
| -from pytensor.graph.basic import Variable |
23 | 21 |
|
24 | 22 | from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType
|
25 | 23 | from pymc.model import modelcontext
|
26 |
| -from pymc.step_methods.compound import CompoundStep |
| 24 | +from pymc.step_methods.compound import BlockedStep |
27 | 25 | from pymc.util import get_var_name
|
28 | 26 |
|
29 |
| -__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select", "Competence"] |
30 |
| - |
31 |
| - |
32 |
@unique
class Competence(IntEnum):
    """Ordered ranking of how well a step method handles a variable.

    Higher values win when a sampler is auto-assigned:

    0: INCOMPATIBLE
    1: COMPATIBLE
    2: PREFERRED
    3: IDEAL
    """

    INCOMPATIBLE = 0
    COMPATIBLE = 1
    PREFERRED = 2
    IDEAL = 3
46 |
| - |
47 |
| - |
48 |
class BlockedStep(ABC):
    """Abstract base for step methods that update one block of variables.

    Subclasses implement :meth:`step`.  Construction is unusual: ``__new__``
    may return a :class:`CompoundStep` (one sub-sampler per variable) instead
    of an instance of ``cls`` when ``blocked=False`` and several variables
    are requested.
    """

    # Per-sampler stat dtypes; subclasses override with their own schema.
    stats_dtypes: List[Dict[str, type]] = []
    # Model value-variables this step method updates.
    vars: List[Variable] = []

    def __new__(cls, *args, **kwargs):
        # Resolve ``blocked``: explicit kwarg wins, then the subclass's
        # ``default_blocked`` attribute, then True.
        blocked = kwargs.get("blocked")
        if blocked is None:
            # Try to look up default value from class
            blocked = getattr(cls, "default_blocked", True)
        # Always forward the resolved flag so __init__ sees it.
        kwargs["blocked"] = blocked

        # Resolve the model from the context stack if not passed explicitly.
        model = modelcontext(kwargs.get("model"))
        kwargs.update({"model": model})

        # vars can either be first arg or a kwarg
        if "vars" not in kwargs and len(args) >= 1:
            vars = args[0]
            args = args[1:]
        elif "vars" in kwargs:
            vars = kwargs.pop("vars")
        else:  # Assume all model variables
            vars = model.value_vars

        # Normalize a single variable to a one-element list.
        if not isinstance(vars, (tuple, list)):
            vars = [vars]

        if len(vars) == 0:
            raise ValueError("No free random variables to sample.")

        if not blocked and len(vars) > 1:
            # In this case we create a separate sampler for each var
            # and append them to a CompoundStep
            steps = []
            for var in vars:
                step = super().__new__(cls)
                # If we don't return the instance we have to manually
                # call __init__
                step.__init__([var], *args, **kwargs)
                # Hack for creating the class correctly when unpickling.
                # NOTE: ``step.__newargs`` is name-mangled to
                # ``_BlockedStep__newargs`` on the instance.
                step.__newargs = ([var],) + args, kwargs
                steps.append(step)

            return CompoundStep(steps)
        else:
            step = super().__new__(cls)
            # Hack for creating the class correctly when unpickling.
            step.__newargs = (vars,) + args, kwargs
            return step

    # Hack for creating the class correctly when unpickling.
    def __getnewargs_ex__(self):
        # Replays the (args, kwargs) captured in __new__ so pickle rebuilds
        # the step through the same dispatch logic.
        return self.__newargs

    @abstractmethod
    def step(self, point: PointType) -> Tuple[PointType, StatsType]:
        """Perform a single step of the sampler."""

    @staticmethod
    def competence(var, has_grad):
        # Default: this step method cannot handle any variable; subclasses
        # override to advertise their suitability.
        return Competence.INCOMPATIBLE

    @classmethod
    def _competence(cls, vars, have_grad):
        """Return one Competence per variable, tolerating subclasses whose
        ``competence`` takes only ``var`` (no ``has_grad``)."""
        vars = np.atleast_1d(vars)
        have_grad = np.atleast_1d(have_grad)
        competences = []
        for var, has_grad in zip(vars, have_grad):
            try:
                competences.append(cls.competence(var, has_grad))
            except TypeError:
                # Older single-argument signature.
                competences.append(cls.competence(var))
        return competences

    def stop_tuning(self):
        # Best-effort: only samplers that expose a ``tune`` flag are affected.
        if hasattr(self, "tune"):
            self.tune = False
| 27 | +__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select"] |
125 | 28 |
|
126 | 29 |
|
127 | 30 | class ArrayStep(BlockedStep):
|
|
0 commit comments