|
21 | 21 |
|
22 | 22 | from aeppl.abstract import _get_measurable_outputs
|
23 | 23 | from aeppl.logprob import _logprob
|
24 |
| -from aesara.graph import FunctionGraph, rewrite_graph |
25 | 24 | from aesara.graph.basic import Node, clone_replace
|
26 | 25 | from aesara.raise_op import Assert
|
27 | 26 | from aesara.tensor import TensorVariable
|
28 | 27 | from aesara.tensor.random.op import RandomVariable
|
29 |
| -from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding |
30 | 28 |
|
31 |
| -from pymc.aesaraf import convert_observed_data, floatX, intX |
| 29 | +from pymc.aesaraf import constant_fold, convert_observed_data, floatX, intX |
32 | 30 | from pymc.distributions import distribution, multivariate
|
33 | 31 | from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
|
34 | 32 | from pymc.distributions.distribution import (
|
|
46 | 44 | convert_dims,
|
47 | 45 | to_tuple,
|
48 | 46 | )
|
| 47 | +from pymc.exceptions import NotConstantValueError |
49 | 48 | from pymc.model import modelcontext
|
50 | 49 | from pymc.util import check_dist_not_registered
|
51 | 50 |
|
@@ -472,14 +471,9 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant:
|
472 | 471 | If inferred ar_order cannot be inferred from rhos or if it is less than 1
|
473 | 472 | """
|
474 | 473 | if ar_order is None:
|
475 |
| - shape_fg = FunctionGraph( |
476 |
| - outputs=[rhos.shape[-1]], |
477 |
| - features=[ShapeFeature()], |
478 |
| - clone=True, |
479 |
| - ) |
480 |
| - (folded_shape,) = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs |
481 |
| - folded_shape = getattr(folded_shape, "data", None) |
482 |
| - if folded_shape is None: |
| 474 | + try: |
| 475 | + (folded_shape,) = constant_fold((rhos.shape[-1],)) |
| 476 | + except NotConstantValueError: |
483 | 477 | raise ValueError(
|
484 | 478 | "Could not infer ar_order from last dimension of rho. Pass it "
|
485 | 479 | "explictily or make sure rho have a static shape"
|
|
0 commit comments