Skip to content

Commit a0f8eff

Browse files
CopilotricardoV94
andcommitted
Replace code snippet with minimal self-contained example for inplace_on_inputs
Co-authored-by: ricardoV94 <[email protected]>
1 parent 7d0ee2d commit a0f8eff

File tree

1 file changed

+61
-5
lines changed

1 file changed

+61
-5
lines changed

doc/extending/inplace.rst

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,67 @@ those inputs where it is safe and beneficial to do so.
219219

220220
.. testcode::
221221

222-
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
223-
"""Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`."""
224-
# Implementation would create a new Op with appropriate destroy_map
225-
# Return self by default if no inplace version is available
226-
return self
222+
import numpy as np
223+
import pytensor
224+
import pytensor.tensor as pt
225+
from pytensor.graph.basic import Apply
226+
from pytensor.graph.op import Op
227+
from pytensor.tensor.blockwise import Blockwise
228+
229+
class MyOpWithInplace(Op):
230+
__props__ = ("destroy_a",)
231+
232+
def __init__(self, destroy_a):
233+
self.destroy_a = destroy_a
234+
if destroy_a:
235+
self.destroy_map = {0: [0]}
236+
237+
def make_node(self, a):
238+
return Apply(self, [a], [a.type()])
239+
240+
def perform(self, node, inputs, output_storage):
241+
[a] = inputs
242+
if not self.destroy_a:
243+
a = a.copy()
244+
a[0] += 1
245+
output_storage[0][0] = a
246+
247+
def inplace_on_inputs(self, allowed_inplace_inputs):
248+
if 0 in allowed_inplace_inputs:
249+
return MyOpWithInplace(destroy_a=True)
250+
return self
251+
252+
a = pt.vector("a")
253+
# Only Blockwise trigger inplace automatically for now
254+
# Since the Blockwise isn't needed in this case, it will be removed after the inplace optimization
255+
op = Blockwise(MyOpWithInplace(destroy_a=False), signature="(a)->(a)")
256+
out = op(a)
257+
258+
# Give PyTensor permission to inplace on user provided inputs
259+
fn = pytensor.function([pytensor.In(a, mutable=True)], out)
260+
261+
# Confirm that we have the inplace version of the Op
262+
fn.dprint(print_destroy_map=True)
263+
264+
.. testoutput::
265+
266+
Blockwise{MyOpWithInplace{destroy_a=True}, (a)->(a)} [id A] '' 5
267+
└─ a [id B]
268+
269+
The output shows that the function now uses the inplace version (`destroy_a=True`).
270+
271+
.. testcode::
272+
273+
# Test that inplace modification works
274+
test_a = np.zeros(5)
275+
result = fn(test_a)
276+
print("Function result:", result)
277+
print("Original array after function call:", test_a)
278+
279+
.. testoutput::
280+
281+
Function result: [1. 0. 0. 0. 0.]
282+
Original array after function call: [1. 0. 0. 0. 0.]
227283

228284
Currently, this method is primarily used with Blockwise operations through PyTensor's
229285
rewriting system, but it will be extended to support core ops directly in future versions.

0 commit comments

Comments
 (0)