22from functools import wraps
33from typing import Callable , Dict , Iterable , List , Optional , Sequence , Set , Tuple , Union
44
5- from aeppl .opt import PreserveRVMappings
5+ from aeppl .rewriting import PreserveRVMappings
66from aesara .compile .builders import OpFromGraph
77from aesara .compile .mode import optdb
88from aesara .graph .basic import Apply , Variable , clone_replace , io_toposort
99from aesara .graph .features import AlreadyThere , Feature
1010from aesara .graph .fg import FunctionGraph
1111from aesara .graph .op import Op
12- from aesara .graph .opt import in2out , local_optimizer
13- from aesara .graph .optdb import SequenceDB
14- from aesara .tensor .basic_opt import ShapeFeature
12+ from aesara .graph .rewriting .basic import in2out , node_rewriter
13+ from aesara .graph .rewriting .db import SequenceDB
1514from aesara .tensor .elemwise import DimShuffle , Elemwise
1615from aesara .tensor .random .op import RandomVariable
1716from aesara .tensor .random .utils import RandomStream
17+ from aesara .tensor .rewriting .basic import ShapeFeature
1818from aesara .tensor .var import TensorVariable
1919from cons .core import _car
2020from unification .core import _unify
2525SamplerFunctionType = Callable [
2626 [FunctionGraph , Apply , RandomStream ], SamplerFunctionReturnType
2727]
28- LocalOptimizerReturnType = Optional [Union [Dict [Variable , Variable ], Sequence [Variable ]]]
28+ LocalRewriterReturnType = Optional [Union [Dict [Variable , Variable ], Sequence [Variable ]]]
2929
3030sampler_ir_db = SequenceDB ()
3131sampler_ir_db .name = "sampler_ir_db"
@@ -90,7 +90,7 @@ def construct_ir_fgraph(
9090 # Update `obs_rvs_to_values` so that it uses the new cloned variables
9191 obs_rvs_to_values = {memo [k ]: v for k , v in obs_rvs_to_values .items ()}
9292
93- sampler_ir_db .query ("+basic" ).optimize (fgraph )
93+ sampler_ir_db .query ("+basic" ).rewrite (fgraph )
9494
9595 new_to_old_rvs = {
9696 new_rv : old_rv for old_rv , new_rv in zip (rv_outputs , fgraph .outputs )
@@ -123,7 +123,7 @@ def on_attach(self, fgraph: FunctionGraph):
123123
124124
125125def sampler_finder (tracks : Optional [Sequence [Union [Op , type ]]]):
126- """Construct a `LocalOptimizer ` that identifies sample steps.
126+ """Construct a `NodeRewriter ` that identifies sample steps.
127127
128128 This is a decorator that is used as follows:
129129
@@ -140,11 +140,11 @@ def local_horseshoe_posterior(fgraph, node, srng):
140140 """
141141
142142 def decorator (f : SamplerFunctionType ):
143- @local_optimizer (tracks )
143+ @node_rewriter (tracks )
144144 @wraps (f )
145145 def sampler_finder (
146146 fgraph : FunctionGraph , node : Apply
147- ) -> LocalOptimizerReturnType :
147+ ) -> LocalRewriterReturnType :
148148 sampler_mappings = getattr (fgraph , "sampler_mappings" , None )
149149
150150 # TODO: This assumes that `node` is a `RandomVariable`-generated `Apply` node
@@ -249,7 +249,7 @@ def car_SubsumingElemwise(x):
249249_car .add ((SubsumingElemwise ,), car_SubsumingElemwise )
250250
251251
252- @local_optimizer ([Elemwise ])
252+ @node_rewriter ([Elemwise ])
253253def local_elemwise_dimshuffle_subsume (fgraph , node ):
254254 r"""This rewrite converts `DimShuffle`s in the `Elemwise` inputs into a single `Op`.
255255
@@ -359,7 +359,7 @@ def local_elemwise_dimshuffle_subsume(fgraph, node):
359359)
360360
361361
362- @local_optimizer ([Elemwise ])
362+ @node_rewriter ([Elemwise ])
363363def inline_SubsumingElemwise (fgraph , node ):
364364
365365 op = node .op
0 commit comments