diff --git a/doc/extending/inplace.rst b/doc/extending/inplace.rst index 74ffa58119..a289290729 100644 --- a/doc/extending/inplace.rst +++ b/doc/extending/inplace.rst @@ -206,6 +206,45 @@ input(s)'s memory). From there, go to the previous section. Consider using :class:`DebugMode` when developing a new :class:`Op` that uses :attr:`Op.view_map` and/or :attr:`Op.destroy_map`. +The `inplace_on_inputs` Method +============================== + +PyTensor provides a method :meth:`Op.inplace_on_inputs` that allows an `Op` to +create a version of itself that operates inplace on as many of the requested +inputs as possible while avoiding inplace operations on non-requested inputs. + +This method takes a list of input indices where inplace operations are allowed +and returns a new `Op` instance that will perform inplace operations only on +those inputs where it is safe and beneficial to do so. + +.. testcode:: + + def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": + """Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`.""" + # Implementation would create a new Op with appropriate destroy_map + # Return self by default if no inplace version is available + return self + +Currently, this method is primarily used with Blockwise operations through PyTensor's +rewriting system, but it will be extended to support core ops directly in future versions. +The rewriting system automatically calls this method to optimize memory usage by +enabling inplace operations where they do not interfere with the computation graph's +correctness. + +When implementing this method in a custom `Op`: + +- Return a new instance of your `Op` with a :attr:`destroy_map` that reflects + the inplace operations on the allowed inputs +- Ensure that inplace operations are only performed on inputs that are in the + `allowed_inplace_inputs` list +- Return `self` if no inplace optimization is possible or beneficial +- The returned `Op` should be functionally equivalent to the original but with + better memory efficiency + +.. note:: + This method is automatically used by PyTensor's optimization pipeline and typically + does not need to be called directly by user code. + Inplace Rewriting and `DebugMode` ================================= diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index 3a00922c87..00490f7357 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -605,7 +605,6 @@ def make_thunk( def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": """Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`.""" - # TODO: Document this in the Create your own Op docs # By default, do nothing return self