Skip to content

Commit da6eaab

Browse files
committed
Refactor get_steps to work with multivariate support shapes
1 parent f7a55c5 commit da6eaab

File tree

4 files changed

+183
-134
lines changed

4 files changed

+183
-134
lines changed

pymc/distributions/shape_utils.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,23 @@
2020
import warnings
2121

2222
from functools import singledispatch
23-
from typing import Optional, Sequence, Tuple, Union
23+
from typing import Any, Optional, Sequence, Tuple, Union
2424

2525
import numpy as np
2626

2727
from aesara import config
2828
from aesara import tensor as at
2929
from aesara.graph.basic import Variable
3030
from aesara.graph.op import Op, compute_test_value
31+
from aesara.raise_op import Assert
3132
from aesara.tensor.random.op import RandomVariable
3233
from aesara.tensor.shape import SpecifyShape
3334
from aesara.tensor.var import TensorVariable
3435
from typing_extensions import TypeAlias
3536

37+
from pymc.aesaraf import convert_observed_data
38+
from pymc.model import modelcontext
39+
3640
__all__ = [
3741
"to_tuple",
3842
"shapes_broadcasting",
@@ -666,3 +670,100 @@ def change_specify_shape_size(op, ss, new_size, expand) -> TensorVariable:
666670

667671
# specify_shape has a wrong signature https://github.com/aesara-devs/aesara/issues/1164
668672
return at.specify_shape(new_var, new_shapes) # type: ignore
673+
674+
675+
def get_support_shape(
    support_shape: Optional[Sequence[Union[int, np.ndarray, TensorVariable]]],
    *,
    shape: Optional[Shape] = None,
    dims: Optional[Dims] = None,
    observed: Optional[Any] = None,
    support_shape_offset: Optional[Sequence[int]] = None,
    ndim_supp: int = 1,
):
    """Extract length of support shapes from shape / dims / observed information.

    Parameters
    ----------
    support_shape:
        User-specified support shape for multivariate distribution
    shape:
        User-specified shape for multivariate distribution
    dims:
        User-specified dims for multivariate distribution
    observed:
        User-specified observed data from multivariate distribution
    support_shape_offset:
        Difference between last shape dimensions and the length of explicit support
        shapes in multivariate distribution, defaults to 0 for every support dimension.
        For timeseries, this is shape[-1] = support_shape[-1] + 1
    ndim_supp:
        Number of support dimensions of the given multivariate distribution, defaults to 1

    Returns
    -------
    support_shape
        Support shape, if specified directly by user, or inferred from the last dimensions
        of shape / dims / observed. When two sources of support shape information are
        provided, a symbolic Assert is added to ensure they are consistent.
    """
    # Fix: the offset default must be typed Optional; a plain `Sequence[int] = None`
    # annotation is wrong. Behavior is unchanged: a missing offset means 0 per dim.
    if support_shape_offset is None:
        support_shape_offset = [0] * ndim_supp
    inferred_support_shape = None

    # Sources are consulted in priority order: shape, then dims, then observed.
    # Only the first available source is used for inference.
    if shape is not None:
        shape = to_tuple(shape)
        assert isinstance(shape, tuple)
        inferred_support_shape = at.stack(
            [shape[-i - 1] - support_shape_offset[-i - 1] for i in range(ndim_supp)]
        )

    if inferred_support_shape is None and dims is not None:
        dims = convert_dims(dims)
        assert isinstance(dims, tuple)
        # Dim lengths live on the model in context; this requires an active model.
        model = modelcontext(None)
        inferred_support_shape = at.stack(
            [
                model.dim_lengths[dims[-i - 1]] - support_shape_offset[-i - 1]  # type: ignore
                for i in range(ndim_supp)
            ]
        )

    if inferred_support_shape is None and observed is not None:
        observed = convert_observed_data(observed)
        inferred_support_shape = at.stack(
            [observed.shape[-i - 1] - support_shape_offset[-i - 1] for i in range(ndim_supp)]
        )

    if inferred_support_shape is None:
        # Nothing could be inferred; fall back to whatever the user passed (may be None).
        inferred_support_shape = support_shape
    # If there are two sources of information for the support shapes, assert they are consistent:
    elif support_shape is not None:
        # NOTE: message kept as-is ("Steps ...") because downstream tests match on it.
        inferred_support_shape = Assert(msg="Steps do not match last shape dimension")(
            inferred_support_shape, at.all(at.eq(inferred_support_shape, support_shape))
        )
    return inferred_support_shape
745+
746+
747+
def get_support_shape_1d(
    support_shape: Optional[Union[int, np.ndarray, TensorVariable]],
    *,
    shape: Optional[Shape] = None,
    dims: Optional[Dims] = None,
    observed: Optional[Any] = None,
    support_shape_offset: int = 0,
):
    """Helper function for cases when you just care about one support dimension.

    Wraps the scalar ``support_shape`` / ``support_shape_offset`` into 1-tuples,
    delegates to :func:`get_support_shape` with the default ``ndim_supp=1``, and
    unwraps the single entry of the result.
    """
    # Bug fix: the tuple must be defined on BOTH branches. Previously it was only
    # assigned when support_shape was not None, so calling this helper with
    # support_shape=None (e.g. inferring steps purely from shape/dims/observed)
    # raised a NameError on the very next line.
    support_shape_tuple = (support_shape,) if support_shape is not None else None

    support_shape_tuple = get_support_shape(
        support_shape_tuple,
        shape=shape,
        dims=dims,
        observed=observed,
        support_shape_offset=(support_shape_offset,),
    )
    if support_shape_tuple is not None:
        # Unpack the single inferred support length back to a scalar.
        (support_shape,) = support_shape_tuple

    return support_shape

pymc/distributions/timeseries.py

Lines changed: 15 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import warnings
1515

16-
from typing import Any, Optional, Union
16+
from typing import Optional
1717

1818
import aesara
1919
import aesara.tensor as at
@@ -24,12 +24,11 @@
2424
from aesara import scan
2525
from aesara.graph import FunctionGraph, rewrite_graph
2626
from aesara.graph.basic import Node, clone_replace
27-
from aesara.raise_op import Assert
2827
from aesara.tensor import TensorVariable
2928
from aesara.tensor.random.op import RandomVariable
3029
from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
3130

32-
from pymc.aesaraf import convert_observed_data, floatX, intX
31+
from pymc.aesaraf import floatX, intX
3332
from pymc.distributions import distribution, multivariate
3433
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
3534
from pymc.distributions.distribution import (
@@ -40,14 +39,11 @@
4039
)
4140
from pymc.distributions.logprob import ignore_logprob, logp
4241
from pymc.distributions.shape_utils import (
43-
Dims,
44-
Shape,
4542
_change_dist_size,
4643
change_dist_size,
47-
convert_dims,
44+
get_support_shape_1d,
4845
to_tuple,
4946
)
50-
from pymc.model import modelcontext
5147
from pymc.util import check_dist_not_registered
5248

5349
__all__ = [
@@ -61,61 +57,6 @@
6157
]
6258

6359

64-
def get_steps(
    steps: Optional[Union[int, np.ndarray, TensorVariable]],
    *,
    shape: Optional[Shape] = None,
    dims: Optional[Dims] = None,
    observed: Optional[Any] = None,
    step_shape_offset: int = 0,
):
    """Extract number of steps from shape / dims / observed information

    Parameters
    ----------
    steps:
        User specified steps for timeseries distribution
    shape:
        User specified shape for timeseries distribution
    dims:
        User specified dims for timeseries distribution
    observed:
        User specified observed data from timeseries distribution
    step_shape_offset:
        Difference between last shape dimension and number of steps in timeseries
        distribution, defaults to 0

    Returns
    -------
    steps
        Steps, if specified directly by user, or inferred from the last dimension of
        shape / dims / observed. When two sources of step information are provided,
        a symbolic Assert is added to ensure they are consistent.
    """
    inferred_steps = None
    # Sources are consulted in priority order: shape first, then dims, then observed.
    if shape is not None:
        shape = to_tuple(shape)
        inferred_steps = shape[-1] - step_shape_offset

    if inferred_steps is None and dims is not None:
        dims = convert_dims(dims)
        # Dim lengths are looked up on the model currently in context.
        model = modelcontext(None)
        inferred_steps = model.dim_lengths[dims[-1]] - step_shape_offset

    if inferred_steps is None and observed is not None:
        observed = convert_observed_data(observed)
        inferred_steps = observed.shape[-1] - step_shape_offset

    if inferred_steps is None:
        # Nothing inferable; fall back to the user-provided steps (may be None).
        inferred_steps = steps
    # If there are two sources of information for the steps, assert they are consistent
    elif steps is not None:
        inferred_steps = Assert(msg="Steps do not match last shape dimension")(
            inferred_steps, at.eq(inferred_steps, steps)
        )
    return inferred_steps
117-
118-
11960
class RandomWalkRV(SymbolicRandomVariable):
12061
"""RandomWalk Variable"""
12162

@@ -132,21 +73,21 @@ class RandomWalk(Distribution):
13273
rv_type = RandomWalkRV
13374

13475
def __new__(cls, *args, steps=None, **kwargs):
135-
steps = get_steps(
136-
steps=steps,
76+
steps = get_support_shape_1d(
77+
support_shape=steps,
13778
shape=None, # Shape will be checked in `cls.dist`
13879
dims=kwargs.get("dims", None),
13980
observed=kwargs.get("observed", None),
140-
step_shape_offset=1,
81+
support_shape_offset=1,
14182
)
14283
return super().__new__(cls, *args, steps=steps, **kwargs)
14384

14485
@classmethod
14586
def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> at.TensorVariable:
146-
steps = get_steps(
147-
steps=steps,
87+
steps = get_support_shape_1d(
88+
support_shape=steps,
14889
shape=kwargs.get("shape"),
149-
step_shape_offset=1,
90+
support_shape_offset=1,
15091
)
15192
if steps is None:
15293
raise ValueError("Must specify steps or shape parameter")
@@ -391,12 +332,12 @@ class AR(Distribution):
391332
def __new__(cls, name, rho, *args, steps=None, constant=False, ar_order=None, **kwargs):
392333
rhos = at.atleast_1d(at.as_tensor_variable(floatX(rho)))
393334
ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order)
394-
steps = get_steps(
395-
steps=steps,
335+
steps = get_support_shape_1d(
336+
support_shape=steps,
396337
shape=None, # Shape will be checked in `cls.dist`
397338
dims=kwargs.get("dims", None),
398339
observed=kwargs.get("observed", None),
399-
step_shape_offset=ar_order,
340+
support_shape_offset=ar_order,
400341
)
401342
return super().__new__(
402343
cls, name, rhos, *args, steps=steps, constant=constant, ar_order=ar_order, **kwargs
@@ -427,7 +368,9 @@ def dist(
427368
init_dist = kwargs.pop("init")
428369

429370
ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order)
430-
steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=ar_order)
371+
steps = get_support_shape_1d(
372+
support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=ar_order
373+
)
431374
if steps is None:
432375
raise ValueError("Must specify steps or shape parameter")
433376
steps = at.as_tensor_variable(intX(steps), ndim=0)

pymc/tests/distributions/test_shape_utils.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
import numpy as np
1919
import pytest
2020

21-
from aesara import Mode
2221
from aesara import tensor as at
22+
from aesara.compile.mode import Mode
2323
from aesara.graph import Constant, ancestors
24+
from aesara.tensor import TensorVariable
2425
from aesara.tensor.random import normal
2526
from aesara.tensor.shape import SpecifyShape
2627

@@ -36,10 +37,12 @@
3637
convert_shape,
3738
convert_size,
3839
get_broadcastable_dist_samples,
40+
get_support_shape_1d,
3941
rv_size_is_none,
4042
shapes_broadcasting,
4143
to_tuple,
4244
)
45+
from pymc.model import Model
4346

4447
test_shapes = [
4548
(tuple(), (1,), (4,), (5, 4)),
@@ -622,3 +625,63 @@ def test_change_specify_shape_size_multivariate():
622625
new_x.eval({batch: 5, supp: 3}).shape == (10, 5, 5, 3)
623626
with pytest.raises(AssertionError, match=re.escape("expected (None, None, 5, 3)")):
624627
new_x.eval({batch: 6, supp: 3}).shape == (10, 5, 5, 3)
628+
629+
630+
# Each case: (explicit steps, shape-like source, offset, expected inferred steps,
# whether the two sources of information agree).
@pytest.mark.parametrize(
    "steps, shape, step_shape_offset, expected_steps, consistent",
    [
        (10, None, 0, 10, True),
        (10, None, 1, 10, True),
        (None, (10,), 0, 10, True),
        (None, (10,), 1, 9, True),
        (None, (10, 5), 0, 5, True),
        (None, None, 0, None, True),
        (10, (10,), 0, 10, True),
        (10, (11,), 1, 10, True),
        (10, (5, 5), 0, 5, False),
        (10, (5, 10), 1, 9, False),
    ],
)
@pytest.mark.parametrize("info_source", ("shape", "dims", "observed"))
def test_get_support_shape_1d(
    info_source, steps, shape, step_shape_offset, expected_steps, consistent
):
    """Check that a 1D support shape is inferred identically from shape, dims, or observed."""
    if info_source == "shape":
        inferred_steps = get_support_shape_1d(
            support_shape=steps, shape=shape, support_shape_offset=step_shape_offset
        )

    elif info_source == "dims":
        if shape is None:
            dims = None
            coords = {}
        else:
            # NOTE(review): the loop variable `shape` shadows the outer `shape`
            # parameter inside these comprehensions; it works but is fragile.
            dims = tuple(str(i) for i, shape in enumerate(shape))
            coords = {str(i): range(shape) for i, shape in enumerate(shape)}
        # dims-based inference needs an active model providing dim_lengths.
        with Model(coords=coords):
            inferred_steps = get_support_shape_1d(
                support_shape=steps, dims=dims, support_shape_offset=step_shape_offset
            )

    elif info_source == "observed":
        if shape is None:
            observed = None
        else:
            observed = np.zeros(shape)
        inferred_steps = get_support_shape_1d(
            support_shape=steps, observed=observed, support_shape_offset=step_shape_offset
        )

    if not isinstance(inferred_steps, TensorVariable):
        # Python scalar (or None) returned directly: compare as-is.
        assert inferred_steps == expected_steps
    else:
        if consistent:
            assert inferred_steps.eval() == expected_steps
        else:
            # check that inferred steps is still correct by ignoring the assert
            f = aesara.function(
                [], inferred_steps, mode=Mode().including("local_remove_all_assert")
            )
            assert f() == expected_steps
            # With the assert in place, inconsistent sources must raise.
            with pytest.raises(AssertionError, match="Steps do not match"):
                inferred_steps.eval()

0 commit comments

Comments
 (0)