from pytensor import tensor as pt
from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.tensor.basic import Join, MakeVector
+from pytensor.tensor.basic import Alloc, Join, MakeVector
from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.extra_ops import BroadcastTo
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import (
    local_dimshuffle_rv_lift,
@@ -59,9 +58,9 @@
from pymc.logprob.utils import check_potential_measurability


-@node_rewriter([BroadcastTo])
+@node_rewriter([Alloc])
def naive_bcast_rv_lift(fgraph, node):
-    """Lift a ``BroadcastTo`` through a ``RandomVariable`` ``Op``.
+    """Lift an ``Alloc`` through a ``RandomVariable`` ``Op``.

    XXX: This implementation simply broadcasts the ``RandomVariable``'s
    parameters, which won't always work (e.g. multivariate distributions).
@@ -73,7 +72,7 @@ def naive_bcast_rv_lift(fgraph, node):
    """

    if not (
-        isinstance(node.op, BroadcastTo)
+        isinstance(node.op, Alloc)
        and node.inputs[0].owner
        and isinstance(node.inputs[0].owner.op, RandomVariable)
    ):
@@ -93,7 +92,7 @@ def naive_bcast_rv_lift(fgraph, node):
        return None

    if not bcast_shape:
-        # The `Alloc` is broadcasting a scalar to a scalar (i.e. doing nothing)
+        # The `Alloc` is broadcasting a scalar to a scalar (i.e. doing nothing)
        assert rv_var.ndim == 0
        return [rv_var]
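
For context, a minimal sketch (not part of this commit, using an illustrative scalar ``normal`` draw) of the graph pattern the updated rewrite now matches: an ``Alloc`` node whose value input is the output of a ``RandomVariable``.

# Hypothetical sketch of the pattern `naive_bcast_rv_lift` matches after this change.
import pytensor.tensor as pt
from pytensor.tensor.basic import Alloc
from pytensor.tensor.random.basic import normal
from pytensor.tensor.random.op import RandomVariable

rv = normal(0, 1)              # scalar draw; its owner op is a `RandomVariable`
bcast_rv = pt.alloc(rv, 3, 2)  # broadcast the draw to shape (3, 2) with an `Alloc`

node = bcast_rv.owner
assert isinstance(node.op, Alloc)                           # outer op is now `Alloc`, not `BroadcastTo`
assert isinstance(node.inputs[0].owner.op, RandomVariable)  # value input comes from a `RandomVariable`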