
Commit 35862ef

Inplace join
1 parent 0b56ed9 commit 35862ef

2 files changed: +160 −2 lines changed

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 10 additions & 0 deletions

@@ -5,6 +5,7 @@
 from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
 from pytensor.link.utils import compile_function_src, unique_name_generator
 from pytensor.tensor import TensorType
+from pytensor.tensor.rewriting.basic import BufferJoin
 from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -103,6 +104,15 @@ def {function_name}({", ".join(input_names)}):
     return numba_njit(func, boundscheck=True)


+@numba_funcify.register(BufferJoin)
+def numba_funcify_buffer_join(op, node, **kwargs):
+    @numba_njit
+    def buffer_join(x, *args):
+        return x
+
+    return buffer_join
+
+
 @numba_funcify.register(AdvancedSubtensor)
 @numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
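
Aside (not part of the commit): the Numba dispatch above is deliberately a no-op at runtime. The extra *args exist only to force the buffer updates to be computed before the join result is returned; by that point the updates have already written into the buffer in place, so forwarding the first argument is all that is needed. A minimal sketch of that contract in plain Python, with illustrative names:

    def buffer_join_runtime(buffer, *buffer_updates):
        # The updates were computed (and written into `buffer` in place)
        # before this call, so the op only needs to forward the buffer.
        return buffer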

pytensor/tensor/rewriting/basic.py

Lines changed: 150 additions & 2 deletions

@@ -23,14 +23,15 @@
 """

 import logging
+import warnings

 import numpy as np

 import pytensor.scalar.basic as ps
 from pytensor import compile, config
 from pytensor.compile.ops import ViewOp
-from pytensor.graph import FunctionGraph
-from pytensor.graph.basic import Constant
+from pytensor.graph import FunctionGraph, rewrite_graph
+from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.rewriting.basic import (
     NodeProcessingGraphRewriter,
     NodeRewriter,
@@ -39,8 +40,10 @@
     copy_stack_trace,
     in2out,
     node_rewriter,
+    out2in,
 )
 from pytensor.graph.rewriting.db import RewriteDatabase
+from pytensor.link.c.op import COp
 from pytensor.raise_op import Assert, CheckAndRaise, assert_op
 from pytensor.scalar.basic import Second
 from pytensor.tensor.basic import (
@@ -55,6 +58,7 @@
     as_tensor_variable,
     atleast_Nd,
     cast,
+    empty,
     fill,
     get_scalar_constant_value,
     join,
@@ -70,6 +74,9 @@
 from pytensor.tensor.extra_ops import broadcast_arrays
 from pytensor.tensor.math import Sum, add, eq, variadic_add
 from pytensor.tensor.shape import Shape_i, shape_padleft
+from pytensor.tensor.subtensor import (
+    Subtensor,
+)
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 from pytensor.utils import NoDuplicateOptWarningFilter
@@ -1356,3 +1363,144 @@ def local_join_of_alloc(fgraph, node):
     new_out = alloc(new_join, *post_join_shape)
     copy_stack_trace(node.outputs[0], new_out)
     return [new_out]
+
+
+class BufferSplit(Subtensor):
+    view_map = {}  # It's a lie, so PyTensor doesn't complain that we mutate the same input in parallel
+
+    def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:
+        return False
+
+
+class BufferJoin(COp):
+    """
+    Returns an inplace view of the input. Used internally by PyTensor.
+
+    """
+
+    # view_map = {0: [0]}
+    # Mapping from Type to C code (and version) to use.
+    # In the C code, the name of the input variable is %(iname)s,
+    # the output variable is %(oname)s.
+    c_code_and_version: dict = {}
+    __props__: tuple = ()
+    _f16_ok: bool = True
+    destroy_map = {0: [0]}
+
+    def make_node(self, buffer_source, *buffer_updates):
+        out = buffer_source.type()
+        return Apply(self, [buffer_source, *buffer_updates], [out])
+
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = inputs[0]
+
+    def c_code(self, node, nodename, inp, out, sub):
+        iname, *_ = inp
+        [oname] = out
+        fail = sub["fail"]
+
+        itype = node.inputs[0].type.__class__
+        if itype in self.c_code_and_version:
+            code, version = self.c_code_and_version[itype]
+            return code % locals()
+
+        # Else, no C code
+        raise NotImplementedError()
+
+    def c_code_cache_version(self):
+        version = []
+        # If any of the C code is unversioned, we have to return ()
+        # Else, we will return a list of (type name, version) pairs.
+        for t, (c, v) in sorted(
+            self.c_code_and_version.items(), key=lambda pair: str(pair[0])
+        ):
+            if not v:
+                warnings.warn(
+                    f"Type {t} has C code for BufferJoin, but it has no "
+                    "version. You should add a 'version' keyword "
+                    "arg when calling register_view_op_c_code.",
+                    stacklevel=2,
+                )
+                return ()
+            version.append((str(t), v))
+
+        return tuple(version)
+
+    def infer_shape(self, fgraph, node, input_shapes):
+        return [input_shapes[0]]
+
+
+buffer_join = BufferJoin()
+
+
+@node_rewriter([Join], inplace=True)
+def inplace_join(fgraph, node):
+    axis, *tensors = node.inputs
+
+    if len(tensors) == 1:
+        return tensors
+
+    if not isinstance(axis, Constant):
+        return
+
+    shape_feature = getattr(fgraph, "shape_feature", None)
+    if shape_feature is None:
+        return
+
+    static_axis = int(axis.data)
+
+    [out] = node.outputs
+    out_shape = shape_feature.shape_of[out]
+    buffer = empty(out_shape, dtype=out.dtype)
+
+    empty_slices = (slice(None),) * static_axis
+    prev_start = None
+    buffer_updates = []
+    for i, y in enumerate(tensors):
+        if not (y.owner is not None and isinstance(y.owner.op, Elemwise)):
+            # We only know how to inplace Elemwise
+            return None
+
+        if prev_start is None:
+            end = shape_feature.shape_of[y][static_axis]
+        elif i == (len(tensors) - 1):
+            end = None
+        else:
+            end = prev_start + shape_feature.shape_of[y][static_axis]
+        tmp_subtensor = buffer[(*empty_slices, slice(prev_start, end))]
+        buffer_view = BufferSplit(tmp_subtensor.owner.op.idx_list)(
+            *tmp_subtensor.owner.inputs
+        )
+        prev_start = end
+
+        from pytensor.tensor.rewriting.elemwise import FusionOptimizer
+
+        scalar_inputs, scalar_outputs = FusionOptimizer.elemwise_to_scalar(
+            (*y.owner.inputs, buffer_view), y.owner.outputs
+        )
+
+        # Set y to overwrite the buffer
+        inplace_pattern = dict(y.owner.op.inplace_pattern)
+        y_idx = y.owner.outputs.index(y)
+        inplace_pattern[y_idx] = len(y.owner.inputs)
+
+        new_op = Elemwise(
+            ps.Composite(scalar_inputs, scalar_outputs), inplace_pattern=inplace_pattern
+        )
+        buffer_update = new_op(*y.owner.inputs, buffer_view, return_list=True)[y_idx]
+        buffer_updates.append(buffer_update)
+
+    out = [buffer_join(buffer, *buffer_updates)]
+    out = rewrite_graph(
+        out, include=("canonicalize",), exclude=("local_useless_composite_outputs",)
+    )
+    return out
+
+
+compile.optdb.register(
+    inplace_join.__name__,
+    out2in(inplace_join),
+    "fast_run",
+    "inplace",
+    position=50.51,  # After the fusion inplace
+)
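
For intuition, here is a minimal NumPy sketch (not part of the commit) of the runtime behaviour this rewrite targets: instead of materializing each Elemwise operand and copying it into the Join output, the Elemwise results are written directly into slices of a preallocated buffer, so the join itself becomes free. The function names below are illustrative only:

    import numpy as np

    # Without the rewrite: each operand is materialized, then copied by concatenate.
    def join_with_copy(a, b):
        return np.concatenate([np.exp(a), np.log1p(b)], axis=0)

    # With the rewrite (conceptually): `empty` allocates the buffer, BufferSplit
    # carves out views of it, each Elemwise writes its result into its view in
    # place, and BufferJoin simply forwards the buffer.
    def join_inplace(a, b):
        out = np.empty(a.shape[0] + b.shape[0], dtype=a.dtype)
        np.exp(a, out=out[: a.shape[0]])
        np.log1p(b, out=out[a.shape[0]:])
        return out

    a, b = np.random.rand(3), np.random.rand(4)
    assert np.allclose(join_with_copy(a, b), join_inplace(a, b))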
