Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from arviz.data.base import make_attrs
from jax.lax import scan
from numpy.typing import ArrayLike
from pytensor.compile import SharedVariable, Supervisor, mode
from pytensor.compile import SharedVariable, mode
from pytensor.graph.basic import graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
Expand Down Expand Up @@ -127,18 +127,6 @@ def get_jaxified_graph(
graph = _replace_shared_variables(outputs) if outputs is not None else None

fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
# We need to add a Supervisor to the fgraph to be able to run the
# JAX sequential optimizer without warnings. We made sure there
# are no mutable input variables, so we only need to check for
# "destroyers". This should be automatically handled by PyTensor
# once https://github.com/aesara-devs/aesara/issues/637 is fixed.
fgraph.attach_feature(
Supervisor(
input
for input in fgraph.inputs
if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
)
)
mode.JAX.optimizer.rewrite(fgraph)

# We now jaxify the optimized fgraph
Expand Down
Loading