14
14
15
15
from abc import ABC , abstractmethod
16
16
from enum import IntEnum , unique
17
- from typing import Dict , List , Tuple , TypeVar , Union
17
+ from typing import Callable , Dict , List , Tuple , Union , cast
18
18
19
19
import numpy as np
20
20
21
21
from aesara .graph .basic import Variable
22
22
from numpy .random import uniform
23
23
24
- from pymc .blocking import DictToArrayBijection , PointType , RaveledVars
24
+ from pymc .blocking import DictToArrayBijection , PointType , RaveledVars , StatsType
25
25
from pymc .model import modelcontext
26
26
from pymc .step_methods .compound import CompoundStep
27
27
from pymc .util import get_var_name
28
28
29
29
__all__ = ["ArrayStep" , "ArrayStepShared" , "metrop_select" , "Competence" ]
30
30
31
- StatsType = TypeVar ("StatsType" )
32
-
33
31
34
32
@unique
35
33
class Competence (IntEnum ):
@@ -49,7 +47,6 @@ class Competence(IntEnum):
49
47
50
48
class BlockedStep (ABC ):
51
49
52
- generates_stats = False
53
50
stats_dtypes : List [Dict [str , type ]] = []
54
51
vars : List [Variable ] = []
55
52
@@ -103,7 +100,7 @@ def __getnewargs_ex__(self):
103
100
return self .__newargs
104
101
105
102
@abstractmethod
106
- def step (point : PointType , * args , ** kwargs ) -> Union [ PointType , Tuple [PointType , StatsType ] ]:
103
+ def step (self , point : PointType ) -> Tuple [PointType , StatsType ]:
107
104
"""Perform a single step of the sampler."""
108
105
109
106
@staticmethod
@@ -146,35 +143,28 @@ def __init__(self, vars, fs, allvars=False, blocked=True):
146
143
self .allvars = allvars
147
144
self .blocked = blocked
148
145
149
- def step (self , point : PointType ):
146
+ def step (self , point : PointType ) -> Tuple [ PointType , StatsType ] :
150
147
151
- partial_funcs_and_point = [DictToArrayBijection .mapf (x , start_point = point ) for x in self .fs ]
148
+ partial_funcs_and_point : List [Union [Callable , PointType ]] = [
149
+ DictToArrayBijection .mapf (x , start_point = point ) for x in self .fs
150
+ ]
152
151
if self .allvars :
153
152
partial_funcs_and_point .append (point )
154
153
155
- apoint = DictToArrayBijection .map ({v .name : point [v .name ] for v in self .vars })
156
- step_res = self .astep (apoint , * partial_funcs_and_point )
157
-
158
- if self .generates_stats :
159
- apoint_new , stats = step_res
160
- else :
161
- apoint_new = step_res
154
+ var_dict = {cast (str , v .name ): point [cast (str , v .name )] for v in self .vars }
155
+ apoint = DictToArrayBijection .map (var_dict )
156
+ apoint_new , stats = self .astep (apoint , * partial_funcs_and_point )
162
157
163
158
if not isinstance (apoint_new , RaveledVars ):
164
159
# We assume that the mapping has stayed the same
165
160
apoint_new = RaveledVars (apoint_new , apoint .point_map_info )
166
161
167
162
point_new = DictToArrayBijection .rmap (apoint_new , start_point = point )
168
163
169
- if self .generates_stats :
170
- return point_new , stats
171
-
172
- return point_new
164
+ return point_new , stats
173
165
174
166
@abstractmethod
175
- def astep (
176
- self , apoint : RaveledVars , point : PointType , * args
177
- ) -> Union [RaveledVars , Tuple [RaveledVars , StatsType ]]:
167
+ def astep (self , apoint : RaveledVars , * args ) -> Tuple [RaveledVars , StatsType ]:
178
168
"""Perform a single sample step in a raveled and concatenated parameter space."""
179
169
180
170
@@ -198,30 +188,27 @@ def __init__(self, vars, shared, blocked=True):
198
188
self .shared = {get_var_name (var ): shared for var , shared in shared .items ()}
199
189
self .blocked = blocked
200
190
201
- def step (self , point ) :
191
+ def step (self , point : PointType ) -> Tuple [ PointType , StatsType ] :
202
192
203
193
for name , shared_var in self .shared .items ():
204
194
shared_var .set_value (point [name ])
205
195
206
- q = DictToArrayBijection .map ({v .name : point [v .name ] for v in self .vars })
196
+ var_dict = {cast (str , v .name ): point [cast (str , v .name )] for v in self .vars }
197
+ q = DictToArrayBijection .map (var_dict )
207
198
208
- step_res = self .astep (q )
209
-
210
- if self .generates_stats :
211
- apoint , stats = step_res
212
- else :
213
- apoint = step_res
199
+ apoint , stats = self .astep (q )
214
200
215
201
if not isinstance (apoint , RaveledVars ):
216
202
# We assume that the mapping has stayed the same
217
203
apoint = RaveledVars (apoint , q .point_map_info )
218
204
219
205
new_point = DictToArrayBijection .rmap (apoint , start_point = point )
220
206
221
- if self .generates_stats :
222
- return new_point , stats
207
+ return new_point , stats
223
208
224
- return new_point
209
+ @abstractmethod
210
+ def astep (self , q0 : RaveledVars ) -> Tuple [RaveledVars , StatsType ]:
211
+ """Perform a single sample step in a raveled and concatenated parameter space."""
225
212
226
213
227
214
class PopulationArrayStepShared (ArrayStepShared ):
@@ -281,7 +268,7 @@ def __init__(
281
268
282
269
super ().__init__ (vars , func ._extra_vars_shared , blocked )
283
270
284
- def step (self , point ):
271
+ def step (self , point ) -> Tuple [ PointType , StatsType ] :
285
272
self ._logp_dlogp_func ._extra_are_set = True
286
273
return super ().step (point )
287
274
0 commit comments