Skip to content

Commit 75e9406

Browse files
committed
Add constant_fold helper
1 parent 7684ec6 commit 75e9406

File tree

6 files changed

+63
-21
lines changed

6 files changed

+63
-21
lines changed

docs/source/api/aesaraf.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Aesara utils
1616
floatX
1717
intX
1818
smartfloatX
19+
constant_fold
1920
CallableTensor
2021
join_nonshared_inputs
2122
make_shared_replacements

pymc/aesaraf.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from aesara import scalar
3535
from aesara.compile.mode import Mode, get_mode
3636
from aesara.gradient import grad
37-
from aesara.graph import node_rewriter
37+
from aesara.graph import node_rewriter, rewrite_graph
3838
from aesara.graph.basic import (
3939
Apply,
4040
Constant,
@@ -55,10 +55,13 @@
5555
RandomGeneratorSharedVariable,
5656
RandomStateSharedVariable,
5757
)
58+
from aesara.tensor.rewriting.basic import topo_constant_folding
59+
from aesara.tensor.rewriting.shape import ShapeFeature
5860
from aesara.tensor.sharedvar import SharedVariable
5961
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
6062
from aesara.tensor.var import TensorConstant, TensorVariable
6163

64+
from pymc.exceptions import NotConstantValueError
6265
from pymc.vartypes import continuous_types, isgenerator, typefilter
6366

6467
PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable]
@@ -82,6 +85,7 @@
8285
"at_rng",
8386
"convert_observed_data",
8487
"compile_pymc",
88+
"constant_fold",
8589
]
8690

8791

@@ -971,3 +975,26 @@ def compile_pymc(
971975
**kwargs,
972976
)
973977
return aesara_function
978+
979+
980+
def constant_fold(xs: Sequence[TensorVariable]) -> Tuple[np.ndarray, ...]:
981+
"""Use constant folding to get constant values of a graph.
982+
983+
Parameters
984+
----------
985+
xs: Sequence of TensorVariable
986+
The variables that are to be constant folded
987+
988+
Raises
989+
------
990+
NotConstantValueError:
991+
If any of the variables cannot be successfully constant folded
992+
"""
993+
fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], clone=True)
994+
995+
folded_xs = rewrite_graph(fg, custom_rewrite=topo_constant_folding).outputs
996+
997+
if not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
998+
raise NotConstantValueError
999+
1000+
return tuple(folded_x.data for folded_x in folded_xs)

pymc/distributions/logprob.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@
2727
from aeppl.tensor import MeasurableJoin
2828
from aeppl.transforms import TransformValuesRewrite
2929
from aesara import tensor as at
30-
from aesara.graph import FunctionGraph, rewrite_graph
3130
from aesara.graph.basic import graph_inputs, io_toposort
3231
from aesara.tensor.random.op import RandomVariable
33-
from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
3432
from aesara.tensor.subtensor import (
3533
AdvancedIncSubtensor,
3634
AdvancedIncSubtensor1,
@@ -41,7 +39,8 @@
4139
)
4240
from aesara.tensor.var import TensorVariable
4341

44-
from pymc.aesaraf import floatX
42+
from pymc.aesaraf import constant_fold, floatX
43+
from pymc.exceptions import NotConstantValueError
4544

4645

4746
def _get_scaling(
@@ -338,12 +337,10 @@ def logprob_join_constant_shapes(op, values, axis, *base_vars, **kwargs):
338337

339338
base_var_shapes = [base_var.shape[axis] for base_var in base_vars]
340339

341-
shape_fg = FunctionGraph(
342-
outputs=base_var_shapes,
343-
features=[ShapeFeature()],
344-
clone=True,
345-
)
346-
base_var_shapes = rewrite_graph(shape_fg, custom_opt=topo_constant_folding).outputs
340+
try:
341+
base_var_shapes = constant_fold(base_var_shapes)
342+
except NotConstantValueError:
343+
pass
347344

348345
split_values = at.split(
349346
value,

pymc/distributions/timeseries.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,12 @@
2121

2222
from aeppl.abstract import _get_measurable_outputs
2323
from aeppl.logprob import _logprob
24-
from aesara.graph import FunctionGraph, rewrite_graph
2524
from aesara.graph.basic import Node, clone_replace
2625
from aesara.raise_op import Assert
2726
from aesara.tensor import TensorVariable
2827
from aesara.tensor.random.op import RandomVariable
29-
from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
3028

31-
from pymc.aesaraf import convert_observed_data, floatX, intX
29+
from pymc.aesaraf import constant_fold, convert_observed_data, floatX, intX
3230
from pymc.distributions import distribution, multivariate
3331
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
3432
from pymc.distributions.distribution import (
@@ -46,6 +44,7 @@
4644
convert_dims,
4745
to_tuple,
4846
)
47+
from pymc.exceptions import NotConstantValueError
4948
from pymc.model import modelcontext
5049
from pymc.util import check_dist_not_registered
5150

@@ -472,14 +471,9 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant:
472471
If inferred ar_order cannot be inferred from rhos or if it is less than 1
473472
"""
474473
if ar_order is None:
475-
shape_fg = FunctionGraph(
476-
outputs=[rhos.shape[-1]],
477-
features=[ShapeFeature()],
478-
clone=True,
479-
)
480-
(folded_shape,) = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
481-
folded_shape = getattr(folded_shape, "data", None)
482-
if folded_shape is None:
474+
try:
475+
(folded_shape,) = constant_fold((rhos.shape[-1],))
476+
except NotConstantValueError:
483477
raise ValueError(
484478
"Could not infer ar_order from last dimension of rho. Pass it "
485479
"explictily or make sure rho have a static shape"

pymc/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,7 @@ def __init__(self, message, actual=None, expected=None):
7878

7979
class TruncationError(RuntimeError):
8080
"""Exception for errors generated from truncated graphs"""
81+
82+
83+
class NotConstantValueError(ValueError):
84+
pass

pymc/tests/test_aesaraf.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from pymc.aesaraf import (
3737
compile_pymc,
38+
constant_fold,
3839
convert_observed_data,
3940
extract_obs_data,
4041
replace_rng_nodes,
@@ -45,6 +46,7 @@
4546
from pymc.distributions.dist_math import check_parameters
4647
from pymc.distributions.distribution import SymbolicRandomVariable
4748
from pymc.distributions.transforms import Interval
49+
from pymc.exceptions import NotConstantValueError
4850
from pymc.vartypes import int_types
4951

5052

@@ -610,3 +612,20 @@ def test_reseed_rngs():
610612
assert rng.get_value()._bit_generator.state == bit_generator.state
611613
else:
612614
assert rng.get_value().bit_generator.state == bit_generator.state
615+
616+
617+
def test_constant_fold():
618+
x = at.random.normal(size=(5,))
619+
y = at.arange(x.size)
620+
621+
res = constant_fold((y, y.shape))
622+
assert np.array_equal(res[0], np.arange(5))
623+
assert tuple(res[1]) == (5,)
624+
625+
626+
def test_constant_fold_error():
627+
x = at.vector("x")
628+
y = at.arange(x.size)
629+
630+
with pytest.raises(NotConstantValueError):
631+
constant_fold((y, y.shape))

0 commit comments

Comments
 (0)