1+ import abc
12import itertools
23import operator
34import sys
45from collections import defaultdict , deque
5- from collections .abc import Generator
6+ from collections .abc import Generator , Sequence
67from functools import cache , reduce
78from typing import TypeVar
89from warnings import warn
1213from pytensor .compile .function .types import Supervisor
1314from pytensor .compile .mode import get_target_language
1415from pytensor .configdefaults import config
15- from pytensor .graph import FunctionGraph
16+ from pytensor .graph import FunctionGraph , Op
1617from pytensor .graph .basic import Apply , Variable , ancestors
1718from pytensor .graph .destroyhandler import DestroyHandler , inplace_candidates
1819from pytensor .graph .features import ReplaceValidate
4748from pytensor .tensor .variable import TensorConstant , TensorVariable
4849
4950
50- class InplaceElemwiseOptimizer (GraphRewriter ):
51- r"""
52- This is parameterized so that it works for `Elemwise` `Op`\s.
53- """
51+ class InplaceGraphOptimizer (GraphRewriter ):
52+ op : type [Op ]
5453
def add_requirements(self, fgraph):
    """Attach a fresh `DestroyHandler` feature to `fgraph`.

    Parameters
    ----------
    fgraph
        The graph this rewriter is about to operate on.
    """
    # NOTE(review): `apply` later reads `fgraph.destroy_handler`; presumably
    # that attribute is provided by this feature -- confirm in DestroyHandler.
    destroy_feature = DestroyHandler()
    fgraph.attach_feature(destroy_feature)
5756
@abc.abstractmethod
def filter_candidate_pairs(
    self, fgraph: FunctionGraph, node: Apply, protected_inputs: Sequence[Variable]
) -> Sequence[tuple[tuple[int, Variable], tuple[int, Variable]]]:
    """Return the (output, input) pairs of `node` that may be made inplace.

    Parameters
    ----------
    fgraph
        The graph being rewritten.
    node
        The node whose `Op` is an instance of ``self.op``.
    protected_inputs
        Variables that must not be destroyed.

    Returns
    -------
    Sequence of ``((out_idx, out), (inp_idx, inp))`` tuples, each pairing an
    output of `node` with an input it could overwrite. When the result is
    empty, `apply` skips the node.
    """
62+
@abc.abstractmethod
def create_inplace_node(
    self, node: Apply, inplace_pattern: dict[int, Sequence[int]]
) -> Apply:
    """Build a version of `node` that operates inplace.

    Parameters
    ----------
    node
        The node to be replaced.
    inplace_pattern
        Mapping from output index to the input indices that the output
        should overwrite.

    Returns
    -------
    A new `Apply` node. `apply` only uses it when the ``destroy_map`` of its
    `Op` is equal to the requested `inplace_pattern`.
    """
68+
5869 def apply (self , fgraph ):
5970 r"""
6071
61- Attempts to replace all `Elemwise`\s by versions of them that operate
62- inplace. It operates greedily: for each `Elemwise` that is encountered,
63- for each output, it tries each input to see if it can operate inplace
64- on that input. If so, it makes the change and goes to the next output
65- or `Elemwise`.
72+ Attempts to replace all `Op`\s by versions of them that operate
73+ inplace. It operates greedily: for each `Op` that is encountered,
74+ it tries to inplace all the valid inputs at once (if the Op supports it),
75+ if that fails, it tries to inplace one input at a time.
6676
6777 Examples
6878 --------
@@ -93,36 +103,13 @@ def apply(self, fgraph):
93103 # tackle them in a more general way. The whole try/except approach is probably suboptimal.
94104 # We can consider restricting inputs with static shapes that are large enough.
95105
96- def create_inplace_node (node , inplace_pattern ):
97- op = node .op
98- scalar_op = op .scalar_op
99- inplace_pattern = {i : o for i , [o ] in inplace_pattern .items ()}
100- if hasattr (scalar_op , "make_new_inplace" ):
101- new_scalar_op = scalar_op .make_new_inplace (
102- ps .transfer_type (
103- * [
104- inplace_pattern .get (i , o .dtype )
105- for i , o in enumerate (node .outputs )
106- ]
107- )
108- )
109- else :
110- new_scalar_op = type (scalar_op )(
111- ps .transfer_type (
112- * [
113- inplace_pattern .get (i , None )
114- for i in range (len (node .outputs ))
115- ]
116- )
117- )
118- return type (op )(new_scalar_op , inplace_pattern ).make_node (* node .inputs )
119-
120106 if config .tensor__insert_inplace_optimizer_validate_nb != - 1 :
121107 warn (
122108 "tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release." ,
123109 FutureWarning ,
124110 )
125111
112+ reason = f"{ self .op } _inplace_optimizer"
126113 prof = {
127114 "opt" : self ,
128115 "node_before" : len (fgraph .apply_nodes ),
@@ -140,43 +127,30 @@ def create_inplace_node(node, inplace_pattern):
140127 protected_inputs .update (fgraph .outputs )
141128 root_destroyer = fgraph .destroy_handler .root_destroyer
142129
130+ self_op = self .op
143131 update_mapping = fgraph .update_mapping or {}
144132 op_updates : dict [TensorVariable , TensorVariable ] = {
145133 out : fgraph .inputs [update_mapping [out_idx ]]
146134 for out_idx , out in enumerate (fgraph .outputs )
147135 if (
148136 out_idx in update_mapping
149137 and out .owner
150- and isinstance (out .owner .op , Elemwise )
138+ and isinstance (out .owner .op , self_op )
151139 )
152140 }
153141 set_op_updates = set (op_updates .keys ())
154142
155143 for node in fgraph .toposort ():
156- if not isinstance (node .op , Elemwise ) or node .op .destroy_map :
144+ if not isinstance (node .op , self_op ) or node .op .destroy_map :
157145 continue
158146
159147 # If big graph and the outputs are scalar, do not make it inplace.
160148 if large_graph and all (node .outputs [0 ].type .broadcastable ):
161149 continue
162150
163- candidate_inputs = [
164- (node .inputs .index (inp ), inp )
165- for inp in inplace_candidates (
166- fgraph ,
167- node .inputs ,
168- protected_inputs = protected_inputs ,
169- )
170- ]
171- if not candidate_inputs :
172- return []
173-
174- candidate_pairs = [
175- ((o , out ), (i , inp ))
176- for o , out in enumerate (node .outputs )
177- for i , inp in candidate_inputs
178- if inp .type == out .type
179- ]
151+ candidate_pairs = self .filter_candidate_pairs (
152+ fgraph , node , protected_inputs
153+ )
180154
181155 if not candidate_pairs :
182156 continue
@@ -216,13 +190,11 @@ def create_inplace_node(node, inplace_pattern):
216190 inplace_pattern [o ] = [i ]
217191 tried_inputs .add (i )
218192
219- inplace_node = create_inplace_node (node , inplace_pattern )
193+ inplace_node = self . create_inplace_node (node , inplace_pattern )
220194 if inplace_node .op .destroy_map == inplace_pattern :
221195 replacements = tuple (zip (node .outputs , inplace_node .outputs ))
222196 try :
223- fgraph .replace_all_validate (
224- replacements , reason = "inplace_elemwise_optimizer"
225- )
197+ fgraph .replace_all_validate (replacements , reason = reason )
226198 except InconsistencyError :
227199 prof ["nb_eager_inconsistent" ] += 1
228200 else :
@@ -238,17 +210,15 @@ def create_inplace_node(node, inplace_pattern):
238210 inplace_pattern [o ] = [i ]
239211 tried_inputs .add (i )
240212
241- inplace_node = create_inplace_node (node , inplace_pattern )
213+ inplace_node = self . create_inplace_node (node , inplace_pattern )
242214 if inplace_node .op .destroy_map != inplace_pattern :
243215 # This Op can't respect this partial inplace pattern,
244216 # We assume it can't support any other cases
245217 break
246218 else :
247219 replacements = tuple (zip (node .outputs , inplace_node .outputs ))
248220 try :
249- fgraph .replace_all_validate (
250- replacements , reason = "inplace_elemwise_optimizer"
251- )
221+ fgraph .replace_all_validate (replacements , reason = reason )
252222 node = inplace_node
253223 replaced = True
254224 except InconsistencyError :
@@ -278,6 +248,50 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
278248 )
279249
280250
class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
    """Inplace rewriter specialised for `Elemwise` nodes."""

    op = Elemwise

    def filter_candidate_pairs(self, fgraph, node, protected_inputs):
        """Pair every output of `node` with each same-typed inplace candidate input.

        Returns a list of ``((out_idx, out), (inp_idx, inp))`` tuples, ordered
        by output first and candidate input second. The list is empty when
        `inplace_candidates` yields nothing for this node.
        """
        destroyable = inplace_candidates(
            fgraph,
            node.inputs,
            protected_inputs=protected_inputs,
        )
        # Remember the position of each candidate among the node inputs.
        indexed_candidates = [(node.inputs.index(inp), inp) for inp in destroyable]
        if not indexed_candidates:
            return []

        pairs = []
        for out_idx, out in enumerate(node.outputs):
            for inp_idx, inp in indexed_candidates:
                # An output may only overwrite an input of the same type.
                if inp.type == out.type:
                    pairs.append(((out_idx, out), (inp_idx, inp)))
        return pairs

    def create_inplace_node(self, node, inplace_pattern):
        """Rebuild `node` with an `Elemwise` that follows `inplace_pattern`.

        `inplace_pattern` maps an output index to a one-element list holding
        the index of the input to overwrite.
        """
        elemwise_op = node.op
        inner_op = elemwise_op.scalar_op
        # Unwrap the single-element lists: output index -> input index.
        flat_pattern = {
            out_idx: inp_idx for out_idx, (inp_idx,) in inplace_pattern.items()
        }

        if hasattr(inner_op, "make_new_inplace"):
            # Outputs that are not inplace fall back to their own dtype.
            transfer_spec = [
                flat_pattern.get(out_idx, out.dtype)
                for out_idx, out in enumerate(node.outputs)
            ]
            inplace_inner_op = inner_op.make_new_inplace(
                ps.transfer_type(*transfer_spec)
            )
        else:
            # Outputs that are not inplace are passed as None.
            transfer_spec = [
                flat_pattern.get(out_idx) for out_idx in range(len(node.outputs))
            ]
            inplace_inner_op = type(inner_op)(ps.transfer_type(*transfer_spec))

        new_op = type(elemwise_op)(inplace_inner_op, flat_pattern)
        return new_op.make_node(*node.inputs)
293+
294+
281295compile .optdb .register (
282296 "inplace_elemwise" ,
283297 InplaceElemwiseOptimizer (),
0 commit comments