1+ import abc
12import itertools
23import operator
34import sys
45from collections import defaultdict , deque
5- from collections .abc import Generator
6+ from collections .abc import Generator , Sequence
67from functools import cache , reduce
78from typing import TypeVar
89from warnings import warn
1213from pytensor .compile .function .types import Supervisor
1314from pytensor .compile .mode import get_target_language
1415from pytensor .configdefaults import config
15- from pytensor .graph import FunctionGraph
16+ from pytensor .graph import FunctionGraph , Op
1617from pytensor .graph .basic import Apply , Variable , ancestors
1718from pytensor .graph .destroyhandler import DestroyHandler , inplace_candidates
1819from pytensor .graph .features import ReplaceValidate
4748from pytensor .tensor .variable import TensorConstant , TensorVariable
4849
4950
50- class InplaceElemwiseOptimizer (GraphRewriter ):
51+ class InplaceGraphOptimizer (GraphRewriter ):
5152 r"""
5253 This is parameterized so that it works for `Elemwise` `Op`\s.
5354 """
5455
56+ op : type [Op ]
57+
def add_requirements(self, fgraph):
    """Attach a new `DestroyHandler` feature to `fgraph`.

    Parameters
    ----------
    fgraph
        The graph this rewriter is about to be applied to; a freshly
        constructed `DestroyHandler` is registered on it through
        ``fgraph.attach_feature``.
    """
    destroy_handler = DestroyHandler()
    fgraph.attach_feature(destroy_handler)
5760
@abc.abstractmethod
def filter_candidate_pairs(
    self,
    fgraph: FunctionGraph,
    node: Apply,
    protected_inputs: Sequence[Variable],
) -> Sequence[tuple[tuple[int, Variable], tuple[int, Variable]]]:
    """Return the (output, input) pairs of `node` that may be tried inplace.

    Subclasses must override this hook.

    Parameters
    ----------
    fgraph
        The graph that contains `node`.
    node
        The node for which inplace candidates are requested.
    protected_inputs
        Variables that must not be offered as inplace candidates.

    Returns
    -------
    A sequence whose elements have the form
    ``((out_idx, out), (in_idx, inp))``. `apply` skips the node when the
    returned sequence is empty.
    """
66+
@abc.abstractmethod
def create_inplace_node(
    self,
    node: Apply,
    inplace_pattern: dict[int, Sequence[int]],
) -> Apply:
    """Build a replacement for `node` that follows `inplace_pattern`.

    Subclasses must override this hook.

    Parameters
    ----------
    node
        The node to rebuild.
    inplace_pattern
        Mapping from an output index to the input indices requested for
        that output.

    Returns
    -------
    The new node. `apply` compares ``op.destroy_map`` of the returned node
    with `inplace_pattern` to find out whether the request was honoured.
    """
72+
5873 def apply (self , fgraph ):
5974 r"""
6075
@@ -93,30 +108,6 @@ def apply(self, fgraph):
93108 # tackle them in a more general way. The whole try/except approach is probably suboptimal.
94109 # We can consider restricting inputs with static shapes that are large enough.
95110
96- def create_inplace_node (node , inplace_pattern ):
97- op = node .op
98- scalar_op = op .scalar_op
99- inplace_pattern = {i : o for i , [o ] in inplace_pattern .items ()}
100- if hasattr (scalar_op , "make_new_inplace" ):
101- new_scalar_op = scalar_op .make_new_inplace (
102- ps .transfer_type (
103- * [
104- inplace_pattern .get (i , o .dtype )
105- for i , o in enumerate (node .outputs )
106- ]
107- )
108- )
109- else :
110- new_scalar_op = type (scalar_op )(
111- ps .transfer_type (
112- * [
113- inplace_pattern .get (i , None )
114- for i in range (len (node .outputs ))
115- ]
116- )
117- )
118- return type (op )(new_scalar_op , inplace_pattern ).make_node (* node .inputs )
119-
120111 if config .tensor__insert_inplace_optimizer_validate_nb != - 1 :
121112 warn (
122113 "tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release." ,
@@ -140,43 +131,30 @@ def create_inplace_node(node, inplace_pattern):
140131 protected_inputs .update (fgraph .outputs )
141132 root_destroyer = fgraph .destroy_handler .root_destroyer
142133
134+ self_op = self .op
143135 update_mapping = fgraph .update_mapping or {}
144136 op_updates : dict [TensorVariable , TensorVariable ] = {
145137 out : fgraph .inputs [update_mapping [out_idx ]]
146138 for out_idx , out in enumerate (fgraph .outputs )
147139 if (
148140 out_idx in update_mapping
149141 and out .owner
150- and isinstance (out .owner .op , Elemwise )
142+ and isinstance (out .owner .op , self_op )
151143 )
152144 }
153145 set_op_updates = set (op_updates .keys ())
154146
155147 for node in fgraph .toposort ():
156- if not isinstance (node .op , Elemwise ) or node .op .destroy_map :
148+ if not isinstance (node .op , self_op ) or node .op .destroy_map :
157149 continue
158150
159151 # If big graph and the outputs are scalar, do not make it inplace.
160152 if large_graph and all (node .outputs [0 ].type .broadcastable ):
161153 continue
162154
163- candidate_inputs = [
164- (node .inputs .index (inp ), inp )
165- for inp in inplace_candidates (
166- fgraph ,
167- node .inputs ,
168- protected_inputs = protected_inputs ,
169- )
170- ]
171- if not candidate_inputs :
172- return []
173-
174- candidate_pairs = [
175- ((o , out ), (i , inp ))
176- for o , out in enumerate (node .outputs )
177- for i , inp in candidate_inputs
178- if inp .type == out .type
179- ]
155+ candidate_pairs = self .filter_candidate_pairs (
156+ fgraph , node , protected_inputs
157+ )
180158
181159 if not candidate_pairs :
182160 continue
@@ -216,7 +194,7 @@ def create_inplace_node(node, inplace_pattern):
216194 inplace_pattern [o ] = [i ]
217195 tried_inputs .add (i )
218196
219- inplace_node = create_inplace_node (node , inplace_pattern )
197+ inplace_node = self . create_inplace_node (node , inplace_pattern )
220198 if inplace_node .op .destroy_map == inplace_pattern :
221199 replacements = tuple (zip (node .outputs , inplace_node .outputs ))
222200 try :
@@ -238,7 +216,7 @@ def create_inplace_node(node, inplace_pattern):
238216 inplace_pattern [o ] = [i ]
239217 tried_inputs .add (i )
240218
241- inplace_node = create_inplace_node (node , inplace_pattern )
219+ inplace_node = self . create_inplace_node (node , inplace_pattern )
242220 if inplace_node .op .destroy_map != inplace_pattern :
243221 # This Op can't respect this partial inplace pattern,
244222 # We assume it can't support any other cases
@@ -277,6 +255,50 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
277255 )
278256
279257
class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
    """`InplaceGraphOptimizer` specialisation whose target `Op` is `Elemwise`."""

    op = Elemwise

    def filter_candidate_pairs(self, fgraph, node, protected_inputs):
        """Pair each output of `node` with the candidate inputs of equal type.

        Parameters
        ----------
        fgraph
            The graph that contains `node`.
        node
            The node whose inputs and outputs are paired.
        protected_inputs
            Forwarded to `inplace_candidates` as ``protected_inputs``.

        Returns
        -------
        A list of ``((out_idx, out), (in_idx, inp))`` tuples, ordered by
        output first and candidate input second. The list is empty when
        `inplace_candidates` yields nothing.
        """
        node_inputs = node.inputs
        indexed_candidates = [
            (node_inputs.index(candidate), candidate)
            for candidate in inplace_candidates(
                fgraph,
                node_inputs,
                protected_inputs=protected_inputs,
            )
        ]
        # Nothing can be overwritten: report "no pairs" without touching the outputs.
        if not indexed_candidates:
            return []

        pairs = []
        for out_idx, out in enumerate(node.outputs):
            for in_idx, inp in indexed_candidates:
                # Only inputs whose type equals the output's type are paired with it.
                if inp.type == out.type:
                    pairs.append(((out_idx, out), (in_idx, inp)))
        return pairs

    def create_inplace_node(self, node, inplace_pattern):
        """Rebuild `node` with an op constructed from `inplace_pattern`.

        Parameters
        ----------
        node
            The node to rebuild; its ``op`` must expose ``scalar_op``.
        inplace_pattern
            Mapping from an output index to a one-element sequence holding
            an input index. Unpacking fails with a ``ValueError`` when a
            value does not contain exactly one element.

        Returns
        -------
        The node produced by calling ``make_node`` on the new op with the
        inputs of `node`.
        """
        outer_op = node.op
        scalar_op = outer_op.scalar_op

        # Flatten ``{out_idx: [in_idx]}`` into ``{out_idx: in_idx}``.
        flat_pattern = {}
        for out_idx, (in_idx,) in inplace_pattern.items():
            flat_pattern[out_idx] = in_idx

        if hasattr(scalar_op, "make_new_inplace"):
            build_scalar_op = scalar_op.make_new_inplace
            # Outputs without an inplace input fall back to their own dtype.
            transfer_args = [
                flat_pattern.get(out_idx, out.dtype)
                for out_idx, out in enumerate(node.outputs)
            ]
        else:
            build_scalar_op = type(scalar_op)
            # Outputs without an inplace input are passed as ``None``.
            transfer_args = [
                flat_pattern.get(out_idx) for out_idx, _ in enumerate(node.outputs)
            ]
        new_scalar_op = build_scalar_op(ps.transfer_type(*transfer_args))

        inplace_op = type(outer_op)(new_scalar_op, flat_pattern)
        return inplace_op.make_node(*node.inputs)
300+
301+
280302compile .optdb .register (
281303 "inplace_elemwise" ,
282304 InplaceElemwiseOptimizer (),
0 commit comments