@@ -6,8 +6,10 @@
 from functools import partial
 from typing import Union, cast

-from pytensor.compile.function import function
+from pytensor.compile import get_default_mode, insert_deepcopy
 from pytensor.compile.function.pfunc import rebuild_collect_shared
+from pytensor.compile.function.types import add_supervisor_to_fgraph
+from pytensor.compile.io import In, Out
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType, Rop, grad
@@ -433,6 +435,7 @@ def __init__(
         assert isinstance(name, str), "name must be None or string object"
         self.name = name
         self.destroy_map = destroy_map if destroy_map is not None else {}
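+        # Populated lazily by _prepare_fgraph with the rewritten inner fgraph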
+        self._prepared_fgraph = None

     def __eq__(self, other):
         # TODO: recognize a copy
@@ -847,16 +850,48 @@ def infer_shape(self, fgraph, node, shapes):

         return ret

+    def _prepare_fgraph(self, impl):
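+        """Rewrite the inner fgraph once and cache it for reuse."""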
+        if self._prepared_fgraph is None:
+            mode = get_default_mode()
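+            # Skip rewrites that require the C backend when a Python thunk is requested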
+            if impl == "py":
+                mode = mode.excluding("cxx")
+            rewriter = mode.optimizer
+
+            fgraph = self.fgraph
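+            # Wrap the graph's I/O the way function compilation would, so the
+            # supervisor and deepcopy insertion see the intended borrow flags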
+            wrapped_inputs = [
+                In(inp, borrow=False, mutable=False) for inp in self.fgraph.inputs
+            ]
+            wrapped_outputs = [Out(out, borrow=True) for out in self.fgraph.outputs]
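+            # Protect the inputs from being destroyed by in-place rewrites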
+            add_supervisor_to_fgraph(
+                fgraph,
+                wrapped_inputs,
+                accept_inplace=False,
+            )
+            rewriter(fgraph)
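+            # Deep-copy outputs that would otherwise alias the inputs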
+            insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs)
+            self._prepared_fgraph = fgraph
+
+        return self._prepared_fgraph
+
     @property
     def fn(self):
         """Lazily compile the inner function graph."""
-        if getattr(self, "_fn", None) is not None:
-            return self._fn
-
-        self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
-        self._fn.trust_input = True
-
-        return self._fn
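+        # Inner-function compilation is disabled for now; make_thunk gets the
+        # rewritten graph from _prepare_fgraph instead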
+        return None
+        # if getattr(self, "_fn", None) is not None:
+        #     return self._fn
+        #
+        # self._fn = pfunc(
+        #     wrapped_inputs,
+        #     wrapped_outputs,
+        #     mode=mode_instance,
+        #     accept_inplace=True,
+        #     on_unused_input="ignore",
+        #     fgraph=self.fgraph,
+        # )
+        # self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
+        # self._fn.trust_input = True
+        #
+        # return self._fn

     @property
     def inner_inputs(self):
@@ -875,11 +910,7 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
         from pytensor.link.c.basic import CLinker
         from pytensor.link.vm import VMLinker

-        # FIXME: Don't call self.fn just to get the optimized fgraph
-        fg = self.fn.maker.fgraph
-        # fg = self.fgraph
-        # rewriter = get_default_mode().optimizer
-        # rewriter(fg)
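+        # Rewrite (and cache) the inner fgraph directly rather than compiling
+        # self.fn just to reach its optimized graph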
+        fg = self._prepare_fgraph(impl)
         fg_no_recycling = [
             new_o
             for (new_o, old_o) in zip(fg.outputs, node.outputs, strict=True)
@@ -890,8 +921,8 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
         node_output_storage = [storage_map[r] for r in node.outputs]

         def create_thunk(linker):
-            linker.accept(fg, no_recycling=fg_no_recycling)
-            thunk, _, _ = linker.make_thunk(
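+            # Hand the linker its own clone so building a thunk cannot mutate
+            # the cached fgraph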
+            linker.accept(fg.clone(), no_recycling=fg_no_recycling)
+            thunk, i, o = linker.make_thunk(
                 input_storage=node_input_storage, output_storage=node_output_storage
             )
