Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit da03cac

Browse files
Change use of "opt[imize]" to "rewrite"
1 parent b461e81 commit da03cac

File tree

11 files changed

+49
-44
lines changed

11 files changed

+49
-44
lines changed

aemcmc/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from aesara.tensor.random.utils import RandomStream
66
from aesara.tensor.var import TensorVariable
77

8-
from aemcmc.opt import (
8+
from aemcmc.rewriting import (
99
SamplerTracker,
1010
construct_ir_fgraph,
1111
expand_subsumptions,
@@ -41,7 +41,7 @@ def construct_sampler(
4141

4242
fgraph.attach_feature(SamplerTracker(srng))
4343

44-
_ = sampler_rewrites_db.query("+basic").optimize(fgraph)
44+
_ = sampler_rewrites_db.query("+basic").rewrite(fgraph)
4545

4646
random_vars = tuple(rv for rv in fgraph.outputs if rv not in obs_rvs_to_values)
4747

@@ -87,7 +87,7 @@ def construct_sampler(
8787
# Update the other sampled random variables in this step's graph
8888
sfgraph.replace_all(list(posterior_sample_steps.items()), import_missing=True)
8989

90-
expand_subsumptions.optimize(sfgraph)
90+
expand_subsumptions.rewrite(sfgraph)
9191

9292
step = sfgraph.outputs[0]
9393

aemcmc/conjugates.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import aesara.tensor as at
2-
from aesara.graph.opt import in2out, local_optimizer
3-
from aesara.graph.optdb import LocalGroupDB
4-
from aesara.graph.unify import eval_if_etuple
2+
from aesara.graph.rewriting.basic import in2out, node_rewriter
3+
from aesara.graph.rewriting.db import LocalGroupDB
4+
from aesara.graph.rewriting.unify import eval_if_etuple
55
from aesara.tensor.random.basic import BinomialRV
66
from etuples import etuple, etuplize
77
from kanren import eq, lall, run
88
from unification import var
99

10-
from aemcmc.opt import sampler_finder_db
10+
from aemcmc.rewriting import sampler_finder_db
1111

1212

1313
def beta_binomial_conjugateo(observed_val, observed_rv_expr, posterior_expr):
@@ -66,7 +66,7 @@ def beta_binomial_conjugateo(observed_val, observed_rv_expr, posterior_expr):
6666
)
6767

6868

69-
@local_optimizer([BinomialRV])
69+
@node_rewriter([BinomialRV])
7070
def local_beta_binomial_posterior(fgraph, node):
7171

7272
sampler_mappings = getattr(fgraph, "sampler_mappings", None)
@@ -98,7 +98,7 @@ def local_beta_binomial_posterior(fgraph, node):
9898
return rv_var.owner.outputs
9999

100100

101-
conjugates_db = LocalGroupDB(apply_all_opts=True)
101+
conjugates_db = LocalGroupDB(apply_all_rewrites=True)
102102
conjugates_db.name = "conjugates_db"
103103
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
104104

aemcmc/gibbs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import aesara
44
import aesara.tensor as at
55
from aesara.graph.basic import Variable
6-
from aesara.graph.opt import in2out
7-
from aesara.graph.optdb import LocalGroupDB
8-
from aesara.graph.unify import eval_if_etuple
6+
from aesara.graph.rewriting.basic import in2out
7+
from aesara.graph.rewriting.db import LocalGroupDB
8+
from aesara.graph.rewriting.unify import eval_if_etuple
99
from aesara.ifelse import ifelse
1010
from aesara.tensor.math import Dot
1111
from aesara.tensor.random import RandomStream
@@ -19,9 +19,9 @@
1919
multivariate_normal_rue2005,
2020
polyagamma,
2121
)
22-
from aemcmc.opt import sampler_finder, sampler_finder_db
22+
from aemcmc.rewriting import sampler_finder, sampler_finder_db
2323

24-
gibbs_db = LocalGroupDB(apply_all_opts=True)
24+
gibbs_db = LocalGroupDB(apply_all_rewrites=True)
2525
gibbs_db.name = "gibbs_db"
2626

2727

aemcmc/nuts.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from aehmc import nuts as aehmc_nuts
55
from aehmc.utils import RaveledParamsMap
66
from aeppl import joint_logprob
7-
from aeppl.transforms import RVTransform, TransformValuesOpt, _default_transformed_rv
7+
from aeppl.transforms import (
8+
RVTransform,
9+
TransformValuesRewrite,
10+
_default_transformed_rv,
11+
)
812
from aesara.tensor.random import RandomStream
913
from aesara.tensor.var import TensorVariable
1014

@@ -62,7 +66,7 @@ def nuts(
6266
}
6367

6468
logprob_sum = joint_logprob(
65-
model.rvs_to_values, extra_rewrites=TransformValuesOpt(transforms)
69+
model.rvs_to_values, extra_rewrites=TransformValuesRewrite(transforms)
6670
)
6771

6872
# Then we transform the value variables.

aemcmc/opt.py renamed to aemcmc/rewriting.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,19 @@
22
from functools import wraps
33
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
44

5-
from aeppl.opt import PreserveRVMappings
5+
from aeppl.rewriting import PreserveRVMappings
66
from aesara.compile.builders import OpFromGraph
77
from aesara.compile.mode import optdb
88
from aesara.graph.basic import Apply, Variable, clone_replace, io_toposort
99
from aesara.graph.features import AlreadyThere, Feature
1010
from aesara.graph.fg import FunctionGraph
1111
from aesara.graph.op import Op
12-
from aesara.graph.opt import in2out, local_optimizer
13-
from aesara.graph.optdb import SequenceDB
14-
from aesara.tensor.basic_opt import ShapeFeature
12+
from aesara.graph.rewriting.basic import in2out, node_rewriter
13+
from aesara.graph.rewriting.db import SequenceDB
1514
from aesara.tensor.elemwise import DimShuffle, Elemwise
1615
from aesara.tensor.random.op import RandomVariable
1716
from aesara.tensor.random.utils import RandomStream
17+
from aesara.tensor.rewriting.basic import ShapeFeature
1818
from aesara.tensor.var import TensorVariable
1919
from cons.core import _car
2020
from unification.core import _unify
@@ -25,7 +25,7 @@
2525
SamplerFunctionType = Callable[
2626
[FunctionGraph, Apply, RandomStream], SamplerFunctionReturnType
2727
]
28-
LocalOptimizerReturnType = Optional[Union[Dict[Variable, Variable], Sequence[Variable]]]
28+
LocalRewriterReturnType = Optional[Union[Dict[Variable, Variable], Sequence[Variable]]]
2929

3030
sampler_ir_db = SequenceDB()
3131
sampler_ir_db.name = "sampler_ir_db"
@@ -90,7 +90,7 @@ def construct_ir_fgraph(
9090
# Update `obs_rvs_to_values` so that it uses the new cloned variables
9191
obs_rvs_to_values = {memo[k]: v for k, v in obs_rvs_to_values.items()}
9292

93-
sampler_ir_db.query("+basic").optimize(fgraph)
93+
sampler_ir_db.query("+basic").rewrite(fgraph)
9494

9595
new_to_old_rvs = {
9696
new_rv: old_rv for old_rv, new_rv in zip(rv_outputs, fgraph.outputs)
@@ -123,7 +123,7 @@ def on_attach(self, fgraph: FunctionGraph):
123123

124124

125125
def sampler_finder(tracks: Optional[Sequence[Union[Op, type]]]):
126-
"""Construct a `LocalOptimizer` that identifies sample steps.
126+
"""Construct a `NodeRewriter` that identifies sample steps.
127127
128128
This is a decorator that is used as follows:
129129
@@ -140,11 +140,11 @@ def local_horseshoe_posterior(fgraph, node, srng):
140140
"""
141141

142142
def decorator(f: SamplerFunctionType):
143-
@local_optimizer(tracks)
143+
@node_rewriter(tracks)
144144
@wraps(f)
145145
def sampler_finder(
146146
fgraph: FunctionGraph, node: Apply
147-
) -> LocalOptimizerReturnType:
147+
) -> LocalRewriterReturnType:
148148
sampler_mappings = getattr(fgraph, "sampler_mappings", None)
149149

150150
# TODO: This assumes that `node` is a `RandomVariable`-generated `Apply` node
@@ -249,7 +249,7 @@ def car_SubsumingElemwise(x):
249249
_car.add((SubsumingElemwise,), car_SubsumingElemwise)
250250

251251

252-
@local_optimizer([Elemwise])
252+
@node_rewriter([Elemwise])
253253
def local_elemwise_dimshuffle_subsume(fgraph, node):
254254
r"""This rewrite converts `DimShuffle`s in the `Elemwise` inputs into a single `Op`.
255255
@@ -359,7 +359,7 @@ def local_elemwise_dimshuffle_subsume(fgraph, node):
359359
)
360360

361361

362-
@local_optimizer([Elemwise])
362+
@node_rewriter([Elemwise])
363363
def inline_SubsumingElemwise(fgraph, node):
364364

365365
op = node.op

environment.yml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ dependencies:
1111
- compilers
1212
- numpy>=1.18.1
1313
- scipy>=1.4.0
14-
- aesara>=2.6.6
15-
- aeppl>=0.0.31
16-
- etuples
14+
- aesara>=2.8.0
15+
- aeppl>=0.0.35
16+
- aehmc>=0.0.9
17+
- polyagamma>=1.3.2
18+
- cons
1719
- logical-unification
20+
- etuples
1821
- miniKanren
19-
- cons
20-
- polyagamma>=1.3.2
2122
# Intel BLAS
2223
- mkl
2324
- mkl-service

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
install_requires=[
1818
"numpy>=1.18.1",
1919
"scipy>=1.4.0",
20-
"aesara>=2.6.6",
21-
"aeppl>=0.0.31",
22-
"aehmc>=0.0.6",
20+
"aesara>=2.8.0",
21+
"aeppl>=0.0.35",
22+
"aehmc>=0.0.9",
2323
"polyagamma>=1.3.2",
2424
"cons",
2525
"logical-unification",

tests/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from scipy.linalg import toeplitz
1010

1111
from aemcmc.basic import construct_sampler
12-
from aemcmc.opt import SubsumingElemwise
12+
from aemcmc.rewriting import SubsumingElemwise
1313

1414

1515
def test_closed_form_posterior():

tests/test_conjugates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import aesara
22
import aesara.tensor as at
33
import pytest
4-
from aesara.graph.unify import eval_if_etuple
4+
from aesara.graph.rewriting.unify import eval_if_etuple
55
from aesara.tensor.random import RandomStream
66
from kanren import run
77
from unification import var

tests/test_gibbs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import scipy.special
66
from aesara.graph.basic import equal_computations
7-
from aesara.graph.opt_utils import optimize_graph
7+
from aesara.graph.rewriting.utils import rewrite_graph
88
from aesara.tensor.random.utils import RandomStream
99
from scipy.linalg import toeplitz
1010

@@ -19,7 +19,7 @@
1919
normal_regression_posterior,
2020
sample_CRT,
2121
)
22-
from aemcmc.opt import SamplerTracker, construct_ir_fgraph, sampler_rewrites_db
22+
from aemcmc.rewriting import SamplerTracker, construct_ir_fgraph, sampler_rewrites_db
2323

2424

2525
@pytest.fixture
@@ -233,7 +233,7 @@ def test_bern_sigmoid_dot_match(srng):
233233
p = at.sigmoid(-eta)
234234
Y_rv = srng.bernoulli(p)
235235

236-
Y_rv = optimize_graph(Y_rv)
236+
Y_rv = rewrite_graph(Y_rv)
237237

238238
assert bern_sigmoid_dot_match(Y_rv)
239239

@@ -289,7 +289,7 @@ def test_gamma_match(srng):
289289
b = at.scalar("b")
290290
beta_rv = srng.gamma(a, b)
291291

292-
beta_rv = optimize_graph(beta_rv)
292+
beta_rv = rewrite_graph(beta_rv)
293293

294294
a_m, b_m = gamma_match(beta_rv)
295295

@@ -326,7 +326,7 @@ def test_nbinom_logistic_horseshoe_finders():
326326

327327
fgraph.attach_feature(SamplerTracker(srng))
328328

329-
_ = sampler_rewrites_db.query("+basic").optimize(fgraph)
329+
_ = sampler_rewrites_db.query("+basic").rewrite(fgraph)
330330

331331
discovered_samplers = fgraph.sampler_mappings.rvs_to_samplers
332332
discovered_samplers = {
@@ -366,7 +366,7 @@ def test_bern_logistic_horseshoe_finders():
366366

367367
fgraph.attach_feature(SamplerTracker(srng))
368368

369-
_ = sampler_rewrites_db.query("+basic").optimize(fgraph)
369+
_ = sampler_rewrites_db.query("+basic").rewrite(fgraph)
370370

371371
discovered_samplers = fgraph.sampler_mappings.rvs_to_samplers
372372
discovered_samplers = {

0 commit comments

Comments
 (0)