Skip to content

Commit ce7b81a

Browse files
ricardoV94lucianopaz
authored andcommitted
Revert "Add constant_fold helper"
This reverts commit 75e9406.
1 parent 75e9406 commit ce7b81a

File tree

6 files changed

+21
-63
lines changed

6 files changed

+21
-63
lines changed

docs/source/api/aesaraf.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ Aesara utils
1616
floatX
1717
intX
1818
smartfloatX
19-
constant_fold
2019
CallableTensor
2120
join_nonshared_inputs
2221
make_shared_replacements

pymc/aesaraf.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from aesara import scalar
3535
from aesara.compile.mode import Mode, get_mode
3636
from aesara.gradient import grad
37-
from aesara.graph import node_rewriter, rewrite_graph
37+
from aesara.graph import node_rewriter
3838
from aesara.graph.basic import (
3939
Apply,
4040
Constant,
@@ -55,13 +55,10 @@
5555
RandomGeneratorSharedVariable,
5656
RandomStateSharedVariable,
5757
)
58-
from aesara.tensor.rewriting.basic import topo_constant_folding
59-
from aesara.tensor.rewriting.shape import ShapeFeature
6058
from aesara.tensor.sharedvar import SharedVariable
6159
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
6260
from aesara.tensor.var import TensorConstant, TensorVariable
6361

64-
from pymc.exceptions import NotConstantValueError
6562
from pymc.vartypes import continuous_types, isgenerator, typefilter
6663

6764
PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable]
@@ -85,7 +82,6 @@
8582
"at_rng",
8683
"convert_observed_data",
8784
"compile_pymc",
88-
"constant_fold",
8985
]
9086

9187

@@ -975,26 +971,3 @@ def compile_pymc(
975971
**kwargs,
976972
)
977973
return aesara_function
978-
979-
980-
def constant_fold(xs: Sequence[TensorVariable]) -> Tuple[np.ndarray, ...]:
981-
"""Use constant folding to get constant values of a graph.
982-
983-
Parameters
984-
----------
985-
xs: Sequence of TensorVariable
986-
The variables that are to be constant folded
987-
988-
Raises
989-
------
990-
NotConstantValueError:
991-
If any of the variables cannot be successfully constant folded
992-
"""
993-
fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], clone=True)
994-
995-
folded_xs = rewrite_graph(fg, custom_rewrite=topo_constant_folding).outputs
996-
997-
if not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
998-
raise NotConstantValueError
999-
1000-
return tuple(folded_x.data for folded_x in folded_xs)

pymc/distributions/logprob.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
from aeppl.tensor import MeasurableJoin
2828
from aeppl.transforms import TransformValuesRewrite
2929
from aesara import tensor as at
30+
from aesara.graph import FunctionGraph, rewrite_graph
3031
from aesara.graph.basic import graph_inputs, io_toposort
3132
from aesara.tensor.random.op import RandomVariable
33+
from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
3234
from aesara.tensor.subtensor import (
3335
AdvancedIncSubtensor,
3436
AdvancedIncSubtensor1,
@@ -39,8 +41,7 @@
3941
)
4042
from aesara.tensor.var import TensorVariable
4143

42-
from pymc.aesaraf import constant_fold, floatX
43-
from pymc.exceptions import NotConstantValueError
44+
from pymc.aesaraf import floatX
4445

4546

4647
def _get_scaling(
@@ -337,10 +338,12 @@ def logprob_join_constant_shapes(op, values, axis, *base_vars, **kwargs):
337338

338339
base_var_shapes = [base_var.shape[axis] for base_var in base_vars]
339340

340-
try:
341-
base_var_shapes = constant_fold(base_var_shapes)
342-
except NotConstantValueError:
343-
pass
341+
shape_fg = FunctionGraph(
342+
outputs=base_var_shapes,
343+
features=[ShapeFeature()],
344+
clone=True,
345+
)
346+
base_var_shapes = rewrite_graph(shape_fg, custom_opt=topo_constant_folding).outputs
344347

345348
split_values = at.split(
346349
value,

pymc/distributions/timeseries.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@
2121

2222
from aeppl.abstract import _get_measurable_outputs
2323
from aeppl.logprob import _logprob
24+
from aesara.graph import FunctionGraph, rewrite_graph
2425
from aesara.graph.basic import Node, clone_replace
2526
from aesara.raise_op import Assert
2627
from aesara.tensor import TensorVariable
2728
from aesara.tensor.random.op import RandomVariable
29+
from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
2830

29-
from pymc.aesaraf import constant_fold, convert_observed_data, floatX, intX
31+
from pymc.aesaraf import convert_observed_data, floatX, intX
3032
from pymc.distributions import distribution, multivariate
3133
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
3234
from pymc.distributions.distribution import (
@@ -44,7 +46,6 @@
4446
convert_dims,
4547
to_tuple,
4648
)
47-
from pymc.exceptions import NotConstantValueError
4849
from pymc.model import modelcontext
4950
from pymc.util import check_dist_not_registered
5051

@@ -471,9 +472,14 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant:
471472
If inferred ar_order cannot be inferred from rhos or if it is less than 1
472473
"""
473474
if ar_order is None:
474-
try:
475-
(folded_shape,) = constant_fold((rhos.shape[-1],))
476-
except NotConstantValueError:
475+
shape_fg = FunctionGraph(
476+
outputs=[rhos.shape[-1]],
477+
features=[ShapeFeature()],
478+
clone=True,
479+
)
480+
(folded_shape,) = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
481+
folded_shape = getattr(folded_shape, "data", None)
482+
if folded_shape is None:
477483
raise ValueError(
478484
"Could not infer ar_order from last dimension of rho. Pass it "
479485
"explictily or make sure rho have a static shape"

pymc/exceptions.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,3 @@ def __init__(self, message, actual=None, expected=None):
7878

7979
class TruncationError(RuntimeError):
8080
"""Exception for errors generated from truncated graphs"""
81-
82-
83-
class NotConstantValueError(ValueError):
84-
pass

pymc/tests/test_aesaraf.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535

3636
from pymc.aesaraf import (
3737
compile_pymc,
38-
constant_fold,
3938
convert_observed_data,
4039
extract_obs_data,
4140
replace_rng_nodes,
@@ -46,7 +45,6 @@
4645
from pymc.distributions.dist_math import check_parameters
4746
from pymc.distributions.distribution import SymbolicRandomVariable
4847
from pymc.distributions.transforms import Interval
49-
from pymc.exceptions import NotConstantValueError
5048
from pymc.vartypes import int_types
5149

5250

@@ -612,20 +610,3 @@ def test_reseed_rngs():
612610
assert rng.get_value()._bit_generator.state == bit_generator.state
613611
else:
614612
assert rng.get_value().bit_generator.state == bit_generator.state
615-
616-
617-
def test_constant_fold():
618-
x = at.random.normal(size=(5,))
619-
y = at.arange(x.size)
620-
621-
res = constant_fold((y, y.shape))
622-
assert np.array_equal(res[0], np.arange(5))
623-
assert tuple(res[1]) == (5,)
624-
625-
626-
def test_constant_fold_error():
627-
x = at.vector("x")
628-
y = at.arange(x.size)
629-
630-
with pytest.raises(NotConstantValueError):
631-
constant_fold((y, y.shape))

0 commit comments

Comments
 (0)