pytensor/link/numba/dispatch/subtensor.py (10 additions, 0 deletions)
@@ -5,6 +5,7 @@
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.basic import BufferJoin
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
@@ -103,6 +104,15 @@ def {function_name}({", ".join(input_names)}):
return numba_njit(func, boundscheck=True)


@numba_funcify.register(BufferJoin)
def numba_funcify_buffer_join(op, node, **kwargs):
@numba_njit
def buffer_join(x, *args):
return x

return buffer_join


@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
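
The Numba implementation above is deliberately a no-op: the update arguments exist only as data dependencies, so the in-place writes are scheduled before the joined result is used, and the buffer itself is returned unchanged. A minimal standalone sketch of that identity behavior (hypothetical toy code, assuming only NumPy and Numba):

import numba
import numpy as np


@numba.njit
def buffer_join(x, *args):
    # `args` are the already-applied in-place updates; they are ignored
    # here and serve purely as data dependencies.
    return x


buf = np.empty(4)
update = np.ones(2)
out = buffer_join(buf, update)
# Numba boxes a fresh ndarray wrapper, but it shares the buffer's memory.
assert np.shares_memory(out, buf)
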
pytensor/tensor/rewriting/basic.py (150 additions, 2 deletions)
@@ -23,14 +23,15 @@
"""

import logging
import warnings

import numpy as np

import pytensor.scalar.basic as ps
from pytensor import compile, config
from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant
from pytensor.graph import FunctionGraph, rewrite_graph
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
NodeRewriter,
@@ -39,8 +40,10 @@
copy_stack_trace,
in2out,
node_rewriter,
out2in,
)
from pytensor.graph.rewriting.db import RewriteDatabase
from pytensor.link.c.op import COp
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar.basic import Second
from pytensor.tensor.basic import (
@@ -55,6 +58,7 @@
as_tensor_variable,
atleast_Nd,
cast,
empty,
fill,
get_scalar_constant_value,
join,
@@ -70,6 +74,9 @@
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import Sum, add, eq, variadic_add
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.subtensor import (
Subtensor,
)
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.variable import TensorConstant, TensorVariable
from pytensor.utils import NoDuplicateOptWarningFilter
@@ -1356,3 +1363,144 @@ def local_join_of_alloc(fgraph, node):
new_out = alloc(new_join, *post_join_shape)
copy_stack_trace(node.outputs[0], new_out)
return [new_out]


class BufferSplit(Subtensor):
    view_map = {}  # It's a lie, so PyTensor doesn't complain that we mutate the same input in parallel

def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:
return False


class BufferJoin(COp):
"""
Returns an inplace view of the input. Used internally by PyTensor.

"""

# view_map = {0: [0]}
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version: dict = {}
__props__: tuple = ()
_f16_ok: bool = True
destroy_map = {0: [0]}

def make_node(self, buffer_source, *buffer_updates):
out = buffer_source.type()
return Apply(self, [buffer_source, *buffer_updates], [out])

def perform(self, node, inputs, output_storage):
output_storage[0][0] = inputs[0]

def c_code(self, node, nodename, inp, out, sub):
iname, *_ = inp
[oname] = out
fail = sub["fail"]

itype = node.inputs[0].type.__class__
if itype in self.c_code_and_version:
code, version = self.c_code_and_version[itype]
return code % locals()

# Else, no C code
raise NotImplementedError()

def c_code_cache_version(self):
version = []
# If any of the c code is unversioned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(
self.c_code_and_version.items(), key=lambda pair: str(pair[0])
):
if not v:
warnings.warn(
f"Type {t} has C code for ViewOp, but it has no "
"version. You should add a 'version' keyword "
"arg when calling register_view_op_c_code.",
stacklevel=2,
)
return ()
version.append((str(t), v))

return tuple(version)

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]


buffer_join = BufferJoin()


@node_rewriter([Join], inplace=True)
def inplace_join(fgraph, node):
    # Rewrite a join of Elemwise outputs so that each Elemwise writes its
    # result directly into the corresponding slice of one preallocated buffer.
    axis, *tensors = node.inputs

    if len(tensors) == 1:
        # A join of a single tensor is just that tensor
        return tensors

if not isinstance(axis, Constant):
return

shape_feature = getattr(fgraph, "shape_feature", None)
if shape_feature is None:
return

static_axis = int(axis.data)

[out] = node.outputs
out_shape = shape_feature.shape_of[out]
buffer = empty(out_shape, dtype=out.dtype)

    empty_slices = (slice(None),) * static_axis  # full slices for the axes before the join axis
prev_start = None
buffer_updates = []
for i, y in enumerate(tensors):
        if y.owner is None or not isinstance(y.owner.op, Elemwise):
            # We only know how to write Elemwise outputs in place
            return None

        # Determine the span this tensor occupies along the join axis;
        # the last tensor takes everything that remains.
        if prev_start is None:
            end = shape_feature.shape_of[y][static_axis]
        elif i == (len(tensors) - 1):
            end = None
        else:
            end = prev_start + shape_feature.shape_of[y][static_axis]
tmp_subtensor = buffer[(*empty_slices, slice(prev_start, end))]
buffer_view = BufferSplit(tmp_subtensor.owner.op.idx_list)(
*tmp_subtensor.owner.inputs
)
prev_start = end

        # Imported locally, presumably to avoid a circular import between rewrite modules
        from pytensor.tensor.rewriting.elemwise import FusionOptimizer

scalar_inputs, scalar_outputs = FusionOptimizer.elemwise_to_scalar(
(*y.owner.inputs, buffer_view), y.owner.outputs
)

        # Make the Elemwise write y's result into the buffer view: append the
        # view as an extra input and mark it as destroyed when computing y
inplace_pattern = dict(y.owner.op.inplace_pattern)
y_idx = y.owner.outputs.index(y)
inplace_pattern[y_idx] = len(y.owner.inputs)

new_op = Elemwise(
ps.Composite(scalar_inputs, scalar_outputs), inplace_pattern=inplace_pattern
)
buffer_update = new_op(*y.owner.inputs, buffer_view, return_list=True)[y_idx]
buffer_updates.append(buffer_update)

out = [buffer_join(buffer, *buffer_updates)]
out = rewrite_graph(
out, include=("canonicalize",), exclude=("local_useless_composite_outputs",)
)
return out
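
The slice bookkeeping above carves the preallocated buffer into consecutive spans along the join axis: the first span starts at the beginning, each middle span starts where the previous one ended, and the last span takes whatever remains. A plain-NumPy sketch of the same arithmetic (hypothetical sizes, joining on axis 0):

import numpy as np

buffer = np.empty((5, 3))
sizes = [2, 3]  # lengths of the joined tensors along the join axis
prev_start, views = None, []
for i, size in enumerate(sizes):
    if prev_start is None:
        end = size
    elif i == len(sizes) - 1:
        end = None  # the last tensor takes the remainder
    else:
        end = prev_start + size
    views.append(buffer[slice(prev_start, end)])
    prev_start = end

# Each view aliases the buffer, so writing into it fills the joined result in place.
assert all(np.shares_memory(v, buffer) for v in views)
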


compile.optdb.register(
inplace_join.__name__,
out2in(inplace_join),
"fast_run",
"inplace",
    position=50.51,  # After the inplace fusion pass
)
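
For context, the pattern this rewrite targets is a Join whose inputs are all Elemwise outputs. A hypothetical end-to-end sketch (whether the rewrite actually fires depends on the shape feature and the other registered rewrites, so treat it as illustrative only):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
# Joining two Elemwise outputs is the candidate pattern: with this rewrite,
# exp and log can write straight into slices of one preallocated buffer.
out = pt.join(0, pt.exp(x), pt.log(y))

fn = pytensor.function([x, y], out, mode="NUMBA")
print(fn(np.zeros(2), np.ones(3)))  # [1. 1. 0. 0. 0.]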