@@ -1,8 +1,9 @@
+import abc
 import itertools
 import operator
 import sys
 from collections import defaultdict, deque
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from functools import cache, reduce
 from typing import TypeVar
 from warnings import warn
@@ -12,7 +13,7 @@
 from pytensor.compile.function.types import Supervisor
 from pytensor.compile.mode import get_target_language
 from pytensor.configdefaults import config
-from pytensor.graph import FunctionGraph
+from pytensor.graph import FunctionGraph, Op
 from pytensor.graph.basic import Apply, Variable, ancestors
 from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
 from pytensor.graph.features import ReplaceValidate
@@ -47,22 +48,31 @@
 from pytensor.tensor.variable import TensorConstant, TensorVariable


-class InplaceElemwiseOptimizer(GraphRewriter):
-    r"""
-    This is parameterized so that it works for `Elemwise` `Op`\s.
-    """
+class InplaceGraphOptimizer(GraphRewriter):
+    op: type[Op]

     def add_requirements(self, fgraph):
         fgraph.attach_feature(DestroyHandler())

+    @abc.abstractmethod
+    def filter_candidate_pairs(
+        self, fgraph: FunctionGraph, node: Apply, protected_inputs: Sequence[Variable]
+    ) -> Sequence[tuple[tuple[int, Variable], tuple[int, Variable]]]:
+        pass
+
+    @abc.abstractmethod
+    def create_inplace_node(
+        self, node: Apply, inplace_pattern: dict[int, Sequence[int]]
+    ) -> Apply:
+        pass
+
     def apply(self, fgraph):
         r"""

-        Attempts to replace all `Elemwise`\s by versions of them that operate
-        inplace. It operates greedily: for each `Elemwise` that is encountered,
-        for each output, it tries each input to see if it can operate inplace
-        on that input. If so, it makes the change and goes to the next output
-        or `Elemwise`.
+        Attempts to replace all `Op`\s by versions of them that operate
+        inplace. It operates greedily: for each `Op` that is encountered,
+        it tries to inplace all the valid inputs at once (if the Op supports it);
+        if that fails, it tries to inplace one input at a time.

         Examples
         --------
@@ -93,36 +103,13 @@ def apply(self, fgraph):
         # tackle them in a more general way. The whole try/except approach is probably suboptimal.
         # We can consider restricting inputs with static shapes that are large enough.

-        def create_inplace_node(node, inplace_pattern):
-            op = node.op
-            scalar_op = op.scalar_op
-            inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
-            if hasattr(scalar_op, "make_new_inplace"):
-                new_scalar_op = scalar_op.make_new_inplace(
-                    ps.transfer_type(
-                        *[
-                            inplace_pattern.get(i, o.dtype)
-                            for i, o in enumerate(node.outputs)
-                        ]
-                    )
-                )
-            else:
-                new_scalar_op = type(scalar_op)(
-                    ps.transfer_type(
-                        *[
-                            inplace_pattern.get(i, None)
-                            for i in range(len(node.outputs))
-                        ]
-                    )
-                )
-            return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
-
         if config.tensor__insert_inplace_optimizer_validate_nb != -1:
             warn(
                 "tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release.",
                 FutureWarning,
             )

+        reason = f"{self.op}_inplace_optimizer"
         prof = {
             "opt": self,
             "node_before": len(fgraph.apply_nodes),
@@ -140,43 +127,30 @@ def create_inplace_node(node, inplace_pattern):
         protected_inputs.update(fgraph.outputs)
         root_destroyer = fgraph.destroy_handler.root_destroyer

+        self_op = self.op
         update_mapping = fgraph.update_mapping or {}
         op_updates: dict[TensorVariable, TensorVariable] = {
             out: fgraph.inputs[update_mapping[out_idx]]
             for out_idx, out in enumerate(fgraph.outputs)
             if (
                 out_idx in update_mapping
                 and out.owner
-                and isinstance(out.owner.op, Elemwise)
+                and isinstance(out.owner.op, self_op)
             )
         }
         set_op_updates = set(op_updates.keys())

         for node in fgraph.toposort():
-            if not isinstance(node.op, Elemwise) or node.op.destroy_map:
+            if not isinstance(node.op, self_op) or node.op.destroy_map:
                 continue

             # If big graph and the outputs are scalar, do not make it inplace.
             if large_graph and all(node.outputs[0].type.broadcastable):
                 continue

-            candidate_inputs = [
-                (node.inputs.index(inp), inp)
-                for inp in inplace_candidates(
-                    fgraph,
-                    node.inputs,
-                    protected_inputs=protected_inputs,
-                )
-            ]
-            if not candidate_inputs:
-                return []
-
-            candidate_pairs = [
-                ((o, out), (i, inp))
-                for o, out in enumerate(node.outputs)
-                for i, inp in candidate_inputs
-                if inp.type == out.type
-            ]
+            candidate_pairs = self.filter_candidate_pairs(
+                fgraph, node, protected_inputs
+            )

             if not candidate_pairs:
                 continue
@@ -216,13 +190,11 @@ def create_inplace_node(node, inplace_pattern):
                     inplace_pattern[o] = [i]
                     tried_inputs.add(i)

-                inplace_node = create_inplace_node(node, inplace_pattern)
+                inplace_node = self.create_inplace_node(node, inplace_pattern)
                 if inplace_node.op.destroy_map == inplace_pattern:
                     replacements = tuple(zip(node.outputs, inplace_node.outputs))
                     try:
-                        fgraph.replace_all_validate(
-                            replacements, reason="inplace_elemwise_optimizer"
-                        )
+                        fgraph.replace_all_validate(replacements, reason=reason)
                     except InconsistencyError:
                         prof["nb_eager_inconsistent"] += 1
                     else:
@@ -238,17 +210,15 @@ def create_inplace_node(node, inplace_pattern):
                     inplace_pattern[o] = [i]
                     tried_inputs.add(i)

-                    inplace_node = create_inplace_node(node, inplace_pattern)
+                    inplace_node = self.create_inplace_node(node, inplace_pattern)
                     if inplace_node.op.destroy_map != inplace_pattern:
                         # This Op can't respect this partial inplace pattern,
                         # We assume it can't support any other cases
                         break
                     else:
                         replacements = tuple(zip(node.outputs, inplace_node.outputs))
                         try:
-                            fgraph.replace_all_validate(
-                                replacements, reason="inplace_elemwise_optimizer"
-                            )
+                            fgraph.replace_all_validate(replacements, reason=reason)
                             node = inplace_node
                             replaced = True
                         except InconsistencyError:
@@ -278,6 +248,50 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
         )


+class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
+    op = Elemwise
+
+    def filter_candidate_pairs(self, fgraph, node, protected_inputs):
+        candidate_inputs = [
+            (node.inputs.index(inp), inp)
+            for inp in inplace_candidates(
+                fgraph,
+                node.inputs,
+                protected_inputs=protected_inputs,
+            )
+        ]
+        if not candidate_inputs:
+            return []
+
+        return [
+            ((o, out), (i, inp))
+            for o, out in enumerate(node.outputs)
+            for i, inp in candidate_inputs
+            if inp.type == out.type
+        ]
+
+    def create_inplace_node(self, node, inplace_pattern):
+        op = node.op
+        scalar_op = op.scalar_op
+        inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
+        if hasattr(scalar_op, "make_new_inplace"):
+            new_scalar_op = scalar_op.make_new_inplace(
+                ps.transfer_type(
+                    *[
+                        inplace_pattern.get(i, o.dtype)
+                        for i, o in enumerate(node.outputs)
+                    ]
+                )
+            )
+        else:
+            new_scalar_op = type(scalar_op)(
+                ps.transfer_type(
+                    *[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
+                )
+            )
+        return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
+
+
 compile.optdb.register(
     "inplace_elemwise",
     InplaceElemwiseOptimizer(),
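
Illustrative sketch (not part of this diff): one way to observe the effect of the greedy inplace rewrite registered above from the user side, assuming a standard PyTensor installation and the default compilation mode. Function inputs are protected (via the `Supervisor` feature), so the inplace replacement lands on an intermediate buffer here rather than on `x` or `y`.

```python
# Minimal sketch, assuming PyTensor's default mode, which runs the
# "inplace_elemwise" rewrite registered above.
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
out = pt.exp(x + y)  # exp is an Elemwise whose input is an intermediate buffer

f = pytensor.function([x, y], out)

# The Exp node is typically replaced by an inplace variant that reuses the
# buffer of the intermediate (x + y); its destroy_map shows up in the print.
pytensor.dprint(f, print_destroy_map=True)
```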