|
21 | 21 |
|
22 | 22 | from aeppl.abstract import _get_measurable_outputs
|
23 | 23 | from aeppl.logprob import _logprob
|
| 24 | +from aesara.graph import FunctionGraph, rewrite_graph |
24 | 25 | from aesara.graph.basic import Node, clone_replace
|
25 | 26 | from aesara.raise_op import Assert
|
26 | 27 | from aesara.tensor import TensorVariable
|
27 | 28 | from aesara.tensor.random.op import RandomVariable
|
| 29 | +from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding |
28 | 30 |
|
29 |
| -from pymc.aesaraf import constant_fold, convert_observed_data, floatX, intX |
| 31 | +from pymc.aesaraf import convert_observed_data, floatX, intX |
30 | 32 | from pymc.distributions import distribution, multivariate
|
31 | 33 | from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
|
32 | 34 | from pymc.distributions.distribution import (
|
|
44 | 46 | convert_dims,
|
45 | 47 | to_tuple,
|
46 | 48 | )
|
47 |
| -from pymc.exceptions import NotConstantValueError |
48 | 49 | from pymc.model import modelcontext
|
49 | 50 | from pymc.util import check_dist_not_registered
|
50 | 51 |
|
@@ -471,9 +472,14 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant:
|
471 | 472 | If inferred ar_order cannot be inferred from rhos or if it is less than 1
|
472 | 473 | """
|
473 | 474 | if ar_order is None:
|
474 |
| - try: |
475 |
| - (folded_shape,) = constant_fold((rhos.shape[-1],)) |
476 |
| - except NotConstantValueError: |
| 475 | + shape_fg = FunctionGraph( |
| 476 | + outputs=[rhos.shape[-1]], |
| 477 | + features=[ShapeFeature()], |
| 478 | + clone=True, |
| 479 | + ) |
| 480 | + (folded_shape,) = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs |
| 481 | + folded_shape = getattr(folded_shape, "data", None) |
| 482 | + if folded_shape is None: |
477 | 483 | raise ValueError(
|
478 | 484 | "Could not infer ar_order from last dimension of rho. Pass it "
|
479 | 485 | "explictily or make sure rho have a static shape"
|
|
0 commit comments