Skip to content

Commit 1ebd078 — "Reuse Elemwise inplace machinery for Blockwise"

Browse files

1 parent: 1d94ed6

File tree

3 files changed: +204 additions, -116 deletions

pytensor/tensor/rewriting/blockwise.py

Lines changed: 62 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytensor.graph import Constant, node_rewriter
33
from pytensor.graph.destroyhandler import inplace_candidates
44
from pytensor.graph.replace import vectorize_node
5-
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
5+
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
66
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
77
from pytensor.tensor.blockwise import Blockwise
88
from pytensor.tensor.math import Dot
@@ -11,6 +11,7 @@
1111
register_specialize,
1212
register_stabilize,
1313
)
14+
from pytensor.tensor.rewriting.elemwise import InplaceGraphOptimizer
1415
from pytensor.tensor.shape import Reshape
1516
from pytensor.tensor.subtensor import (
1617
AdvancedIncSubtensor,
@@ -260,68 +261,77 @@ def local_blockwise_of_subtensor(fgraph, node):
260261
return [x[(*none_slices, *core_idxs)]]
261262

262263

263-
@node_rewriter(tracks=[Blockwise], inplace=True)
264-
def blockwise_inplace(fgraph, node):
265-
blockwise_op = node.op
266-
267-
if blockwise_op.destroy_map:
268-
# Op already has inplace
269-
return
270-
271-
# Find out valid inputs for inplacing
272-
batch_ndim = blockwise_op.batch_ndim(node)
273-
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
274-
275-
inputs = node.inputs
276-
candidate_inputs = set(
277-
inplace_candidates(
278-
fgraph,
279-
[
280-
inp
281-
for inp in inputs
282-
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
283-
],
264+
class InplaceBlockwiseOptimizer(InplaceGraphOptimizer):
265+
op = Blockwise
266+
267+
def filter_candidate_pairs(self, fgraph, node, protected_inputs):
268+
blockwise_op = node.op
269+
batch_ndim = blockwise_op.batch_ndim(node)
270+
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
271+
inputs = node.inputs
272+
273+
candidate_inputs = set(
274+
inplace_candidates(
275+
fgraph,
276+
[
277+
inp
278+
for inp in inputs
279+
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
280+
],
281+
protected_inputs=protected_inputs,
282+
)
284283
)
285-
)
286-
allowed_inplace_inputs = [
287-
i for i, inp in enumerate(inputs) if inp in candidate_inputs
288-
]
289284

290-
if not allowed_inplace_inputs:
291-
return None
285+
allowed_inplace_inputs = [
286+
i for i, inp in enumerate(inputs) if inp in candidate_inputs
287+
]
288+
destroy_map = blockwise_op.core_op.inplace_on_inputs(
289+
allowed_inplace_inputs=allowed_inplace_inputs
290+
).destroy_map
291+
292+
if not destroy_map:
293+
return []
294+
295+
outputs = node.outputs
296+
return [
297+
((out_idx, outputs[out_idx]), (inp_idx, inputs[inp_idx]))
298+
for out_idx, inp_idxs in destroy_map.items()
299+
for inp_idx in inp_idxs
300+
]
292301

293-
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
294-
allowed_inplace_inputs=allowed_inplace_inputs
295-
)
302+
def create_inplace_node(self, node, inplace_pattern):
303+
blockwise_op = node.op
304+
allowed_inplace_inputs = tuple(v[0] for v in inplace_pattern.values())
305+
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
306+
allowed_inplace_inputs=allowed_inplace_inputs
307+
)
296308

297-
if not inplace_core_op.destroy_map:
298-
return None
309+
if not inplace_core_op.destroy_map:
310+
return node
299311

300-
# Check Op is not trying to inplace on non-candidate inputs
301-
for destroyed_inputs in inplace_core_op.destroy_map.values():
302-
for destroyed_input in destroyed_inputs:
303-
if destroyed_input not in allowed_inplace_inputs:
304-
raise ValueError(
305-
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
306-
)
312+
# Check Op is not trying to inplace on non-candidate inputs
313+
for destroyed_inputs in inplace_core_op.destroy_map.values():
314+
for destroyed_input in destroyed_inputs:
315+
if destroyed_input not in allowed_inplace_inputs:
316+
raise ValueError(
317+
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
318+
)
307319

308-
# Recreate core_op with inplace
309-
inplace_blockwise_op = Blockwise(
310-
core_op=inplace_core_op,
311-
signature=blockwise_op.signature,
312-
name=blockwise_op.name,
313-
gufunc_spec=blockwise_op.gufunc_spec,
314-
destroy_map=inplace_core_op.destroy_map,
315-
)
320+
# Recreate core_op with inplace
321+
inplace_blockwise_op = type(blockwise_op)(
322+
core_op=inplace_core_op,
323+
signature=blockwise_op.signature,
324+
name=blockwise_op.name,
325+
gufunc_spec=blockwise_op.gufunc_spec,
326+
destroy_map=inplace_core_op.destroy_map,
327+
)
316328

317-
out = inplace_blockwise_op.make_node(*node.inputs).outputs
318-
copy_stack_trace(node.outputs, out)
319-
return out
329+
return inplace_blockwise_op.make_node(*node.inputs)
320330

321331

322332
optdb.register(
323333
"blockwise_inplace",
324-
in2out(blockwise_inplace),
334+
InplaceBlockwiseOptimizer(),
325335
"fast_run",
326336
"inplace",
327337
position=50.1,

pytensor/tensor/rewriting/elemwise.py

Lines changed: 76 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import abc
12
import itertools
23
import operator
34
import sys
45
from collections import defaultdict, deque
5-
from collections.abc import Generator
6+
from collections.abc import Generator, Sequence
67
from functools import cache, reduce
78
from typing import TypeVar
89
from warnings import warn
@@ -12,7 +13,7 @@
1213
from pytensor.compile.function.types import Supervisor
1314
from pytensor.compile.mode import get_target_language
1415
from pytensor.configdefaults import config
15-
from pytensor.graph import FunctionGraph
16+
from pytensor.graph import FunctionGraph, Op
1617
from pytensor.graph.basic import Apply, Variable, ancestors
1718
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
1819
from pytensor.graph.features import ReplaceValidate
@@ -47,22 +48,31 @@
4748
from pytensor.tensor.variable import TensorConstant, TensorVariable
4849

4950

50-
class InplaceElemwiseOptimizer(GraphRewriter):
51-
r"""
52-
This is parameterized so that it works for `Elemwise` `Op`\s.
53-
"""
51+
class InplaceGraphOptimizer(GraphRewriter):
52+
op: type[Op]
5453

5554
def add_requirements(self, fgraph):
5655
fgraph.attach_feature(DestroyHandler())
5756

57+
@abc.abstractmethod
58+
def filter_candidate_pairs(
59+
self, fgraph: FunctionGraph, node: Apply, protected_inputs: Sequence[Variable]
60+
) -> Sequence[tuple[tuple[int, Variable], tuple[int, Variable]]]:
61+
pass
62+
63+
@abc.abstractmethod
64+
def create_inplace_node(
65+
self, node: Apply, inplace_pattern: dict[int, Sequence[int]]
66+
) -> Apply:
67+
pass
68+
5869
def apply(self, fgraph):
5970
r"""
6071
61-
Attempts to replace all `Elemwise`\s by versions of them that operate
62-
inplace. It operates greedily: for each `Elemwise` that is encountered,
63-
for each output, it tries each input to see if it can operate inplace
64-
on that input. If so, it makes the change and goes to the next output
65-
or `Elemwise`.
72+
Attempts to replace all `Op`\s by versions of them that operate
73+
inplace. It operates greedily: for each `Op` that is encountered,
74+
it tries to inplace all the valid inputs at once (if the Op supports it),
75+
if that fails, it tries to inplace one input at a time.
6676
6777
Examples
6878
--------
@@ -93,36 +103,13 @@ def apply(self, fgraph):
93103
# tackle them in a more general way. The whole try/except approach is probably suboptimal.
94104
# We can consider restricting inputs with static shapes that are large enough.
95105

96-
def create_inplace_node(node, inplace_pattern):
97-
op = node.op
98-
scalar_op = op.scalar_op
99-
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
100-
if hasattr(scalar_op, "make_new_inplace"):
101-
new_scalar_op = scalar_op.make_new_inplace(
102-
ps.transfer_type(
103-
*[
104-
inplace_pattern.get(i, o.dtype)
105-
for i, o in enumerate(node.outputs)
106-
]
107-
)
108-
)
109-
else:
110-
new_scalar_op = type(scalar_op)(
111-
ps.transfer_type(
112-
*[
113-
inplace_pattern.get(i, None)
114-
for i in range(len(node.outputs))
115-
]
116-
)
117-
)
118-
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
119-
120106
if config.tensor__insert_inplace_optimizer_validate_nb != -1:
121107
warn(
122108
"tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release.",
123109
FutureWarning,
124110
)
125111

112+
reason = f"{self.op}_inplace_optimizer"
126113
prof = {
127114
"opt": self,
128115
"node_before": len(fgraph.apply_nodes),
@@ -140,43 +127,30 @@ def create_inplace_node(node, inplace_pattern):
140127
protected_inputs.update(fgraph.outputs)
141128
root_destroyer = fgraph.destroy_handler.root_destroyer
142129

130+
self_op = self.op
143131
update_mapping = fgraph.update_mapping or {}
144132
op_updates: dict[TensorVariable, TensorVariable] = {
145133
out: fgraph.inputs[update_mapping[out_idx]]
146134
for out_idx, out in enumerate(fgraph.outputs)
147135
if (
148136
out_idx in update_mapping
149137
and out.owner
150-
and isinstance(out.owner.op, Elemwise)
138+
and isinstance(out.owner.op, self_op)
151139
)
152140
}
153141
set_op_updates = set(op_updates.keys())
154142

155143
for node in fgraph.toposort():
156-
if not isinstance(node.op, Elemwise) or node.op.destroy_map:
144+
if not isinstance(node.op, self_op) or node.op.destroy_map:
157145
continue
158146

159147
# If big graph and the outputs are scalar, do not make it inplace.
160148
if large_graph and all(node.outputs[0].type.broadcastable):
161149
continue
162150

163-
candidate_inputs = [
164-
(node.inputs.index(inp), inp)
165-
for inp in inplace_candidates(
166-
fgraph,
167-
node.inputs,
168-
protected_inputs=protected_inputs,
169-
)
170-
]
171-
if not candidate_inputs:
172-
return []
173-
174-
candidate_pairs = [
175-
((o, out), (i, inp))
176-
for o, out in enumerate(node.outputs)
177-
for i, inp in candidate_inputs
178-
if inp.type == out.type
179-
]
151+
candidate_pairs = self.filter_candidate_pairs(
152+
fgraph, node, protected_inputs
153+
)
180154

181155
if not candidate_pairs:
182156
continue
@@ -216,13 +190,11 @@ def create_inplace_node(node, inplace_pattern):
216190
inplace_pattern[o] = [i]
217191
tried_inputs.add(i)
218192

219-
inplace_node = create_inplace_node(node, inplace_pattern)
193+
inplace_node = self.create_inplace_node(node, inplace_pattern)
220194
if inplace_node.op.destroy_map == inplace_pattern:
221195
replacements = tuple(zip(node.outputs, inplace_node.outputs))
222196
try:
223-
fgraph.replace_all_validate(
224-
replacements, reason="inplace_elemwise_optimizer"
225-
)
197+
fgraph.replace_all_validate(replacements, reason=reason)
226198
except InconsistencyError:
227199
prof["nb_eager_inconsistent"] += 1
228200
else:
@@ -238,17 +210,15 @@ def create_inplace_node(node, inplace_pattern):
238210
inplace_pattern[o] = [i]
239211
tried_inputs.add(i)
240212

241-
inplace_node = create_inplace_node(node, inplace_pattern)
213+
inplace_node = self.create_inplace_node(node, inplace_pattern)
242214
if inplace_node.op.destroy_map != inplace_pattern:
243215
# This Op can't respect this partial inplace pattern,
244216
# We assume it can't support any other cases
245217
break
246218
else:
247219
replacements = tuple(zip(node.outputs, inplace_node.outputs))
248220
try:
249-
fgraph.replace_all_validate(
250-
replacements, reason="inplace_elemwise_optimizer"
251-
)
221+
fgraph.replace_all_validate(replacements, reason=reason)
252222
node = inplace_node
253223
replaced = True
254224
except InconsistencyError:
@@ -278,6 +248,50 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
278248
)
279249

280250

251+
class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
252+
op = Elemwise
253+
254+
def filter_candidate_pairs(self, fgraph, node, protected_inputs):
255+
candidate_inputs = [
256+
(node.inputs.index(inp), inp)
257+
for inp in inplace_candidates(
258+
fgraph,
259+
node.inputs,
260+
protected_inputs=protected_inputs,
261+
)
262+
]
263+
if not candidate_inputs:
264+
return []
265+
266+
return [
267+
((o, out), (i, inp))
268+
for o, out in enumerate(node.outputs)
269+
for i, inp in candidate_inputs
270+
if inp.type == out.type
271+
]
272+
273+
def create_inplace_node(self, node, inplace_pattern):
274+
op = node.op
275+
scalar_op = op.scalar_op
276+
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
277+
if hasattr(scalar_op, "make_new_inplace"):
278+
new_scalar_op = scalar_op.make_new_inplace(
279+
ps.transfer_type(
280+
*[
281+
inplace_pattern.get(i, o.dtype)
282+
for i, o in enumerate(node.outputs)
283+
]
284+
)
285+
)
286+
else:
287+
new_scalar_op = type(scalar_op)(
288+
ps.transfer_type(
289+
*[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
290+
)
291+
)
292+
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
293+
294+
281295
compile.optdb.register(
282296
"inplace_elemwise",
283297
InplaceElemwiseOptimizer(),

Comments (0)