Skip to content

Commit 6cdfc30

Browse files
committed
Rename compile_pymc to compile
1 parent a714b24 commit 6cdfc30

File tree

21 files changed

+88
-81
lines changed

21 files changed

+88
-81
lines changed

docs/source/api/pytensorf.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ PyTensor utils
66
.. autosummary::
77
:toctree: generated/
88

9-
compile_pymc
9+
compile
1010
gradient
1111
hessian
1212
hessian_diag
@@ -19,6 +19,4 @@ PyTensor utils
1919
CallableTensor
2020
join_nonshared_inputs
2121
make_shared_replacements
22-
generator
23-
convert_generator_data
2422
convert_data

pymc/backends/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from pymc.backends.report import SamplerReport
3636
from pymc.model import modelcontext
37-
from pymc.pytensorf import compile_pymc
37+
from pymc.pytensorf import compile
3838
from pymc.util import get_var_name
3939

4040
logger = logging.getLogger(__name__)
@@ -171,7 +171,7 @@ def __init__(
171171

172172
if fn is None:
173173
# borrow=True avoids deepcopy when inputs=output which is the case for untransformed value variables
174-
fn = compile_pymc(
174+
fn = compile(
175175
inputs=[pytensor.In(v, borrow=True) for v in model.value_vars],
176176
outputs=[pytensor.Out(v, borrow=True) for v in vars],
177177
on_unused_input="ignore",

pymc/func_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,18 +169,18 @@ def find_constrained_prior(
169169
)
170170

171171
target = (pt.exp(logcdf_lower) - mass_below_lower) ** 2
172-
target_fn = pm.pytensorf.compile_pymc([dist_params], target, allow_input_downcast=True)
172+
target_fn = pm.pytensorf.compile([dist_params], target, allow_input_downcast=True)
173173

174174
constraint = pt.exp(logcdf_upper) - pt.exp(logcdf_lower)
175-
constraint_fn = pm.pytensorf.compile_pymc([dist_params], constraint, allow_input_downcast=True)
175+
constraint_fn = pm.pytensorf.compile([dist_params], constraint, allow_input_downcast=True)
176176

177177
jac: str | Callable
178178
constraint_jac: str | Callable
179179
try:
180180
pytensor_jac = pm.gradient(target, [dist_params])
181-
jac = pm.pytensorf.compile_pymc([dist_params], pytensor_jac, allow_input_downcast=True)
181+
jac = pm.pytensorf.compile([dist_params], pytensor_jac, allow_input_downcast=True)
182182
pytensor_constraint_jac = pm.gradient(constraint, [dist_params])
183-
constraint_jac = pm.pytensorf.compile_pymc(
183+
constraint_jac = pm.pytensorf.compile(
184184
[dist_params], pytensor_constraint_jac, allow_input_downcast=True
185185
)
186186
# when PyMC cannot compute the gradient

pymc/gp/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from scipy.cluster.vq import kmeans
2424

2525
from pymc.model.core import modelcontext
26-
from pymc.pytensorf import compile_pymc
26+
from pymc.pytensorf import compile
2727

2828
JITTER_DEFAULT = 1e-6
2929

@@ -55,7 +55,7 @@ def replace_with_values(vars_needed, replacements=None, model=None):
5555
if len(inputs) == 0:
5656
return tuple(v.eval() for v in vars_needed)
5757

58-
fn = compile_pymc(
58+
fn = compile(
5959
inputs,
6060
vars_needed,
6161
allow_input_downcast=True,

pymc/initial_point.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from pymc.logprob.transforms import Transform
2828
from pymc.pytensorf import (
29-
compile_pymc,
29+
compile,
3030
find_rng_nodes,
3131
replace_rng_nodes,
3232
reseed_rngs,
@@ -157,7 +157,7 @@ def make_initial_point_fn(
157157
# Replace original rng shared variables so that we don't mess with them
158158
# when calling the final seeded function
159159
initial_values = replace_rng_nodes(initial_values)
160-
func = compile_pymc(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)
160+
func = compile(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)
161161

162162
varnames = []
163163
for var in model.free_RVs:

pymc/model/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from pymc.pytensorf import (
5757
PointFunc,
5858
SeedSequenceSeed,
59-
compile_pymc,
59+
compile,
6060
convert_observed_data,
6161
gradient,
6262
hessian,
@@ -253,7 +253,7 @@ def __init__(
253253
)
254254
inputs = grad_vars
255255

256-
self._pytensor_function = compile_pymc(inputs, outputs, givens=givens, **kwargs)
256+
self._pytensor_function = compile(inputs, outputs, givens=givens, **kwargs)
257257
self._raveled_inputs = ravel_inputs
258258

259259
def set_weights(self, values):
@@ -1637,7 +1637,7 @@ def compile_fn(
16371637
inputs = inputvars(outs)
16381638

16391639
with self:
1640-
fn = compile_pymc(
1640+
fn = compile(
16411641
inputs,
16421642
outs,
16431643
allow_input_downcast=True,

pymc/pytensorf.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060

6161
__all__ = [
6262
"CallableTensor",
63+
"compile",
6364
"compile_pymc",
6465
"cont_inputs",
6566
"convert_data",
@@ -981,7 +982,7 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
981982
return rng_updates
982983

983984

984-
def compile_pymc(
985+
def compile(
985986
inputs,
986987
outputs,
987988
random_seed: SeedSequenceSeed = None,
@@ -990,7 +991,7 @@ def compile_pymc(
990991
) -> Function:
991992
"""Use ``pytensor.function`` with specialized pymc rewrites always enabled.
992993
993-
This function also ensures shared RandomState/Generator used by RandomVariables
994+
This function also ensures shared Generator used by RandomVariables
994995
in the graph are updated across calls, to ensure independent draws.
995996
996997
Parameters
@@ -1061,6 +1062,14 @@ def compile_pymc(
10611062
return pytensor_function
10621063

10631064

1065+
def compile_pymc(*args, **kwargs):
1066+
warnings.warn(
1067+
"compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC",
1068+
FutureWarning,
1069+
)
1070+
return compile(*args, **kwargs)
1071+
1072+
10641073
def constant_fold(
10651074
xs: Sequence[TensorVariable], raise_not_constant: bool = True
10661075
) -> tuple[np.ndarray | Variable, ...]:

pymc/sampling/forward.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from pymc.backends.base import MultiTrace
5252
from pymc.blocking import PointType
5353
from pymc.model import Model, modelcontext
54-
from pymc.pytensorf import compile_pymc
54+
from pymc.pytensorf import compile
5555
from pymc.util import (
5656
CustomProgress,
5757
RandomState,
@@ -273,7 +273,7 @@ def expand(node):
273273
]
274274

275275
return (
276-
compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
276+
compile(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
277277
set(basic_rvs) & (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled
278278
)
279279

@@ -329,7 +329,7 @@ def draw(
329329
if random_seed is not None:
330330
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
331331

332-
draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs)
332+
draw_fn = compile(inputs=[], outputs=vars, random_seed=random_seed, **kwargs)
333333

334334
if draws == 1:
335335
return draw_fn()

pymc/smc/kernels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from pymc.initial_point import make_initial_point_expression
3131
from pymc.model import Point, modelcontext
3232
from pymc.pytensorf import (
33-
compile_pymc,
33+
compile,
3434
floatX,
3535
join_nonshared_inputs,
3636
make_shared_replacements,
@@ -636,6 +636,6 @@ def _logp_forw(point, out_vars, in_vars, shared):
636636
out_list, inarray0 = join_nonshared_inputs(
637637
point=point, outputs=out_vars, inputs=in_vars, shared_inputs=shared
638638
)
639-
f = compile_pymc([inarray0], out_list[0])
639+
f = compile([inarray0], out_list[0])
640640
f.trust_input = True
641641
return f

pymc/step_methods/metropolis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from pymc.initial_point import PointType
3232
from pymc.pytensorf import (
3333
CallableTensor,
34-
compile_pymc,
34+
compile,
3535
floatX,
3636
join_nonshared_inputs,
3737
replace_rng_nodes,
@@ -1241,6 +1241,6 @@ def delta_logp(
12411241

12421242
if compile_kwargs is None:
12431243
compile_kwargs = {}
1244-
f = compile_pymc([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
1244+
f = compile([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
12451245
f.trust_input = True
12461246
return f

0 commit comments

Comments
 (0)