From 35862ef0f9168fc57cfaa8ce37d7b99c6091cc4e Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 31 Mar 2025 16:11:36 +0200
Subject: [PATCH] Inplace join

---
 pytensor/link/numba/dispatch/subtensor.py |  10 ++
 pytensor/tensor/rewriting/basic.py        | 152 +++++++++++++++++++++-
 2 files changed, 160 insertions(+), 2 deletions(-)

diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index ee9e183d16..35980d34f1 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -5,6 +5,7 @@
 from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
 from pytensor.link.utils import compile_function_src, unique_name_generator
 from pytensor.tensor import TensorType
+from pytensor.tensor.rewriting.basic import BufferJoin
 from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -103,6 +104,15 @@ def {function_name}({", ".join(input_names)}):
     return numba_njit(func, boundscheck=True)
 
 
+@numba_funcify.register(BufferJoin)
+def numba_funcify_buffer_join(op, node, **kwargs):
+    @numba_njit
+    def buffer_join(x, *args):
+        return x
+
+    return buffer_join
+
+
 @numba_funcify.register(AdvancedSubtensor)
 @numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 59148fae3b..90b2b37966 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -23,14 +23,15 @@
 """
 
 import logging
+import warnings
 
 import numpy as np
 
 import pytensor.scalar.basic as ps
 from pytensor import compile, config
 from pytensor.compile.ops import ViewOp
-from pytensor.graph import FunctionGraph
-from pytensor.graph.basic import Constant
+from pytensor.graph import FunctionGraph, rewrite_graph
+from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.rewriting.basic import (
     NodeProcessingGraphRewriter,
     NodeRewriter,
@@ -39,8 +40,10 @@
     copy_stack_trace,
     in2out,
     node_rewriter,
+    out2in,
 )
 from pytensor.graph.rewriting.db import RewriteDatabase
+from pytensor.link.c.op import COp
 from pytensor.raise_op import Assert, CheckAndRaise, assert_op
 from pytensor.scalar.basic import Second
 from pytensor.tensor.basic import (
@@ -55,6 +58,7 @@
     as_tensor_variable,
     atleast_Nd,
     cast,
+    empty,
     fill,
     get_scalar_constant_value,
     join,
@@ -70,6 +74,9 @@
 from pytensor.tensor.extra_ops import broadcast_arrays
 from pytensor.tensor.math import Sum, add, eq, variadic_add
 from pytensor.tensor.shape import Shape_i, shape_padleft
+from pytensor.tensor.subtensor import (
+    Subtensor,
+)
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 from pytensor.utils import NoDuplicateOptWarningFilter
@@ -1356,3 +1363,144 @@ def local_join_of_alloc(fgraph, node):
         new_out = alloc(new_join, *post_join_shape)
         copy_stack_trace(node.outputs[0], new_out)
         return [new_out]
+
+
+class BufferSplit(Subtensor):
+    view_map = {}  # It's a lie, so PyTensor doesn't complain that we are mutating the same input in parallel
+
+    def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:
+        return False
+
+
+class BufferJoin(COp):
+    """
+    Returns the pre-allocated join buffer once all buffer updates have run. Used internally by PyTensor.
+
+    """
+
+    # view_map = {0: [0]}
+    # Mapping from Type to C code (and version) to use.
+    # In the C code, the name of the input variable is %(iname)s,
+    # the output variable is %(oname)s.
+    c_code_and_version: dict = {}
+    __props__: tuple = ()
+    _f16_ok: bool = True
+    destroy_map = {0: [0]}
+
+    def make_node(self, buffer_source, *buffer_updates):
+        out = buffer_source.type()
+        return Apply(self, [buffer_source, *buffer_updates], [out])
+
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = inputs[0]
+
+    def c_code(self, node, nodename, inp, out, sub):
+        iname, *_ = inp
+        [oname] = out
+        fail = sub["fail"]
+
+        itype = node.inputs[0].type.__class__
+        if itype in self.c_code_and_version:
+            code, version = self.c_code_and_version[itype]
+            return code % locals()
+
+        # Else, no C code
+        raise NotImplementedError()
+
+    def c_code_cache_version(self):
+        version = []
+        # If any of the c code is unversioned, we have to return ()
+        # Else, we will return a list of (type name, version) pairs.
+        for t, (c, v) in sorted(
+            self.c_code_and_version.items(), key=lambda pair: str(pair[0])
+        ):
+            if not v:
+                warnings.warn(
+                    f"Type {t} has C code for BufferJoin, but it has no "
+                    "version. You should add a 'version' keyword "
+                    "arg when registering the C code for BufferJoin.",
+                    stacklevel=2,
+                )
+                return ()
+            version.append((str(t), v))
+
+        return tuple(version)
+
+    def infer_shape(self, fgraph, node, input_shapes):
+        return [input_shapes[0]]
+
+
+buffer_join = BufferJoin()
+
+
+@node_rewriter([Join], inplace=True)
+def inplace_join(fgraph, node):
+    axis, *tensors = node.inputs
+
+    if len(tensors) == 1:
+        return tensors
+
+    if not isinstance(axis, Constant):
+        return
+
+    shape_feature = getattr(fgraph, "shape_feature", None)
+    if shape_feature is None:
+        return
+
+    static_axis = int(axis.data)
+
+    [out] = node.outputs
+    out_shape = shape_feature.shape_of[out]
+    buffer = empty(out_shape, dtype=out.dtype)
+
+    empty_slices = (slice(None),) * static_axis
+    prev_start = None
+    buffer_updates = []
+    for i, y in enumerate(tensors):
+        if not (y.owner is not None and isinstance(y.owner.op, Elemwise)):
+            # We only know how to inplace Elemwise
+            return None
+
+        if prev_start is None:
+            end = shape_feature.shape_of[y][static_axis]
+        elif i == (len(tensors) - 1):
+            end = None
+        else:
+            end = prev_start + shape_feature.shape_of[y][static_axis]
+        tmp_subtensor = buffer[(*empty_slices, slice(prev_start, end))]
+        buffer_view = BufferSplit(tmp_subtensor.owner.op.idx_list)(
+            *tmp_subtensor.owner.inputs
+        )
+        prev_start = end
+
+        from pytensor.tensor.rewriting.elemwise import FusionOptimizer
+
+        scalar_inputs, scalar_outputs = FusionOptimizer.elemwise_to_scalar(
+            (*y.owner.inputs, buffer_view), y.owner.outputs
+        )
+
+        # Set y to overwrite the buffer
+        inplace_pattern = dict(y.owner.op.inplace_pattern)
+        y_idx = y.owner.outputs.index(y)
+        inplace_pattern[y_idx] = len(y.owner.inputs)
+
+        new_op = Elemwise(
+            ps.Composite(scalar_inputs, scalar_outputs), inplace_pattern=inplace_pattern
+        )
+        buffer_update = new_op(*y.owner.inputs, buffer_view, return_list=True)[y_idx]
+        buffer_updates.append(buffer_update)
+
+    out = [buffer_join(buffer, *buffer_updates)]
+    out = rewrite_graph(
+        out, include=("canonicalize",), exclude=("local_useless_composite_outputs",)
+    )
+    return out
+
+
+compile.optdb.register(
+    inplace_join.__name__,
+    out2in(inplace_join),
+    "fast_run",
+    "inplace",
+    position=50.51,  # After the fusion inplace
+)
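
Illustrative sketch of the intended runtime behaviour (plain NumPy, not the patch's API; the arrays `x`, `y` and the `exp`/`log1p` ufuncs are arbitrary placeholders). Instead of materializing each Elemwise result and then copying it into the Join output, the join buffer is allocated once and each Elemwise writes straight into its own slice, which is the role the BufferSplit views, the in-place Composite Elemwise, and BufferJoin play in the rewrite above.

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.random((3, 4))
    y = rng.random((5, 4))

    # Status quo: two temporaries plus a copy into the concatenated output.
    joined = np.concatenate([np.exp(x), np.log1p(y)], axis=0)

    # What the rewrite targets: allocate the join buffer up front and let each
    # elemwise write directly into its slice of the buffer, avoiding the copies.
    buffer = np.empty((x.shape[0] + y.shape[0], x.shape[1]), dtype=joined.dtype)
    np.exp(x, out=buffer[: x.shape[0]])
    np.log1p(y, out=buffer[x.shape[0] :])

    np.testing.assert_allclose(joined, buffer)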