13
13
# limitations under the License.
14
14
import warnings
15
15
16
- from typing import Any , Optional , Union
16
+ from typing import Optional
17
17
18
18
import aesara
19
19
import aesara .tensor as at
24
24
from aesara import scan
25
25
from aesara .graph import FunctionGraph , rewrite_graph
26
26
from aesara .graph .basic import Node , clone_replace
27
- from aesara .raise_op import Assert
28
27
from aesara .tensor import TensorVariable
29
28
from aesara .tensor .random .op import RandomVariable
30
29
from aesara .tensor .rewriting .basic import ShapeFeature , topo_constant_folding
31
30
32
- from pymc .aesaraf import convert_observed_data , floatX , intX
31
+ from pymc .aesaraf import floatX , intX
33
32
from pymc .distributions import distribution , multivariate
34
33
from pymc .distributions .continuous import Flat , Normal , get_tau_sigma
35
34
from pymc .distributions .distribution import (
40
39
)
41
40
from pymc .distributions .logprob import ignore_logprob , logp
42
41
from pymc .distributions .shape_utils import (
43
- Dims ,
44
- Shape ,
45
42
_change_dist_size ,
46
43
change_dist_size ,
47
- convert_dims ,
44
+ get_support_shape_1d ,
48
45
to_tuple ,
49
46
)
50
- from pymc .model import modelcontext
51
47
from pymc .util import check_dist_not_registered
52
48
53
49
__all__ = [
61
57
]
62
58
63
59
64
- def get_steps (
65
- steps : Optional [Union [int , np .ndarray , TensorVariable ]],
66
- * ,
67
- shape : Optional [Shape ] = None ,
68
- dims : Optional [Dims ] = None ,
69
- observed : Optional [Any ] = None ,
70
- step_shape_offset : int = 0 ,
71
- ):
72
- """Extract number of steps from shape / dims / observed information
73
-
74
- Parameters
75
- ----------
76
- steps:
77
- User specified steps for timeseries distribution
78
- shape:
79
- User specified shape for timeseries distribution
80
- dims:
81
- User specified dims for timeseries distribution
82
- observed:
83
- User specified observed data from timeseries distribution
84
- step_shape_offset:
85
- Difference between last shape dimension and number of steps in timeseries
86
- distribution, defaults to 0
87
-
88
- Returns
89
- -------
90
- steps
91
- Steps, if specified directly by user, or inferred from the last dimension of
92
- shape / dims / observed. When two sources of step information are provided,
93
- a symbolic Assert is added to ensure they are consistent.
94
- """
95
- inferred_steps = None
96
- if shape is not None :
97
- shape = to_tuple (shape )
98
- inferred_steps = shape [- 1 ] - step_shape_offset
99
-
100
- if inferred_steps is None and dims is not None :
101
- dims = convert_dims (dims )
102
- model = modelcontext (None )
103
- inferred_steps = model .dim_lengths [dims [- 1 ]] - step_shape_offset
104
-
105
- if inferred_steps is None and observed is not None :
106
- observed = convert_observed_data (observed )
107
- inferred_steps = observed .shape [- 1 ] - step_shape_offset
108
-
109
- if inferred_steps is None :
110
- inferred_steps = steps
111
- # If there are two sources of information for the steps, assert they are consistent
112
- elif steps is not None :
113
- inferred_steps = Assert (msg = "Steps do not match last shape dimension" )(
114
- inferred_steps , at .eq (inferred_steps , steps )
115
- )
116
- return inferred_steps
117
-
118
-
119
60
class RandomWalkRV (SymbolicRandomVariable ):
120
61
"""RandomWalk Variable"""
121
62
@@ -132,21 +73,21 @@ class RandomWalk(Distribution):
132
73
rv_type = RandomWalkRV
133
74
134
75
def __new__ (cls , * args , steps = None , ** kwargs ):
135
- steps = get_steps (
136
- steps = steps ,
76
+ steps = get_support_shape_1d (
77
+ support_shape = steps ,
137
78
shape = None , # Shape will be checked in `cls.dist`
138
79
dims = kwargs .get ("dims" , None ),
139
80
observed = kwargs .get ("observed" , None ),
140
- step_shape_offset = 1 ,
81
+ support_shape_offset = 1 ,
141
82
)
142
83
return super ().__new__ (cls , * args , steps = steps , ** kwargs )
143
84
144
85
@classmethod
145
86
def dist (cls , init_dist , innovation_dist , steps = None , ** kwargs ) -> at .TensorVariable :
146
- steps = get_steps (
147
- steps = steps ,
87
+ steps = get_support_shape_1d (
88
+ support_shape = steps ,
148
89
shape = kwargs .get ("shape" ),
149
- step_shape_offset = 1 ,
90
+ support_shape_offset = 1 ,
150
91
)
151
92
if steps is None :
152
93
raise ValueError ("Must specify steps or shape parameter" )
@@ -391,12 +332,12 @@ class AR(Distribution):
391
332
def __new__ (cls , name , rho , * args , steps = None , constant = False , ar_order = None , ** kwargs ):
392
333
rhos = at .atleast_1d (at .as_tensor_variable (floatX (rho )))
393
334
ar_order = cls ._get_ar_order (rhos = rhos , constant = constant , ar_order = ar_order )
394
- steps = get_steps (
395
- steps = steps ,
335
+ steps = get_support_shape_1d (
336
+ support_shape = steps ,
396
337
shape = None , # Shape will be checked in `cls.dist`
397
338
dims = kwargs .get ("dims" , None ),
398
339
observed = kwargs .get ("observed" , None ),
399
- step_shape_offset = ar_order ,
340
+ support_shape_offset = ar_order ,
400
341
)
401
342
return super ().__new__ (
402
343
cls , name , rhos , * args , steps = steps , constant = constant , ar_order = ar_order , ** kwargs
@@ -427,7 +368,9 @@ def dist(
427
368
init_dist = kwargs .pop ("init" )
428
369
429
370
ar_order = cls ._get_ar_order (rhos = rhos , constant = constant , ar_order = ar_order )
430
- steps = get_steps (steps = steps , shape = kwargs .get ("shape" , None ), step_shape_offset = ar_order )
371
+ steps = get_support_shape_1d (
372
+ support_shape = steps , shape = kwargs .get ("shape" , None ), support_shape_offset = ar_order
373
+ )
431
374
if steps is None :
432
375
raise ValueError ("Must specify steps or shape parameter" )
433
376
steps = at .as_tensor_variable (intX (steps ), ndim = 0 )
0 commit comments