1 file changed: 1 addition, 13 deletions
@@ -30,7 +30,7 @@
 from arviz.data.base import make_attrs
 from jax.lax import scan
 from numpy.typing import ArrayLike
-from pytensor.compile import SharedVariable, Supervisor, mode
+from pytensor.compile import SharedVariable, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
@@ -127,18 +127,6 @@ def get_jaxified_graph(
     graph = _replace_shared_variables(outputs) if outputs is not None else None

     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
-    # We need to add a Supervisor to the fgraph to be able to run the
-    # JAX sequential optimizer without warnings. We made sure there
-    # are no mutable input variables, so we only need to check for
-    # "destroyers". This should be automatically handled by PyTensor
-    # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
-    fgraph.attach_feature(
-        Supervisor(
-            input
-            for input in fgraph.inputs
-            if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
-        )
-    )
     mode.JAX.optimizer.rewrite(fgraph)

     # We now jaxify the optimized fgraph
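After this change, `get_jaxified_graph` simply builds the `FunctionGraph` and runs the JAX rewrites on it directly, with no `Supervisor` feature attached first. Below is a minimal sketch of that pattern, assuming an installed `pytensor` with its JAX backend; the toy graph (`x`, `y`) is illustrative and not part of this PR.

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile import mode
from pytensor.graph.fg import FunctionGraph

# A toy graph standing in for the model outputs (illustrative only).
x = pt.vector("x")
y = pt.exp(x).sum()

# Same shape as the patched code path: build the FunctionGraph
# (clone=True leaves the original graph untouched) and apply the
# JAX rewrites directly; no Supervisor feature is attached first.
fgraph = FunctionGraph(inputs=[x], outputs=[y], clone=True)
mode.JAX.optimizer.rewrite(fgraph)

# The public equivalent: compiling with mode="JAX" runs the same
# rewrite pipeline and returns a JAX-backed function.
f = pytensor.function([x], y, mode="JAX")
print(f([0.0, 1.0, 2.0]))  # sum(exp(x))
```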