@@ -219,11 +219,67 @@ those inputs where it is safe and beneficial to do so.
219
219
220
220
.. testcode ::
221
221
222
- def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
223
- """Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs `."""
224
- # Implementation would create a new Op with appropriate destroy_map
225
- # Return self by default if no inplace version is available
226
- return self
222
+ import numpy as np
223
+ import pytensor
224
+ import pytensor.tensor as pt
225
+ from pytensor.graph.basic import Apply
226
+ from pytensor.graph.op import Op
227
+ from pytensor.tensor.blockwise import Blockwise
228
+
229
+ class MyOpWithInplace(Op):
230
+ __props__ = ("destroy_a",)
231
+
232
+ def __init__(self, destroy_a):
233
+ self.destroy_a = destroy_a
234
+ if destroy_a:
235
+ self.destroy_map = {0: [0]}
236
+
237
+ def make_node(self, a):
238
+ return Apply(self, [a], [a.type()])
239
+
240
+ def perform(self, node, inputs, output_storage):
241
+ [a] = inputs
242
+ if not self.destroy_a:
243
+ a = a.copy()
244
+ a[0] += 1
245
+ output_storage[0][0] = a
246
+
247
+ def inplace_on_inputs(self, allowed_inplace_inputs):
248
+ if 0 in allowed_inplace_inputs:
249
+ return MyOpWithInplace(destroy_a=True)
250
+ return self
251
+
252
+ a = pt.vector("a")
253
+ # Only Blockwise trigger inplace automatically for now
254
+ # Since the Blockwise isn't needed in this case, it will be removed after the inplace optimization
255
+ op = Blockwise(MyOpWithInplace(destroy_a=False), signature="(a)->(a)")
256
+ out = op(a)
257
+
258
+ # Give PyTensor permission to inplace on user provided inputs
259
+ fn = pytensor.function([pytensor.In(a, mutable=True)], out)
260
+
261
+ # Confirm that we have the inplace version of the Op
262
+ fn.dprint(print_destroy_map=True)
263
+
264
+ .. testoutput ::
265
+
266
+ Blockwise{MyOpWithInplace{destroy_a=True}, (a)->(a)} [id A] '' 5
267
+ └─ a [id B]
268
+
269
+ The output shows that the function now uses the inplace version (`destroy_a=True `).
270
+
271
+ .. testcode ::
272
+
273
+ # Test that inplace modification works
274
+ test_a = np.zeros(5)
275
+ result = fn(test_a)
276
+ print("Function result:", result)
277
+ print("Original array after function call:", test_a)
278
+
279
+ .. testoutput ::
280
+
281
+ Function result: [1. 0. 0. 0. 0.]
282
+ Original array after function call: [1. 0. 0. 0. 0.]
227
283
228
284
Currently, this method is primarily used with Blockwise operations through PyTensor's
229
285
rewriting system, but it will be extended to support core ops directly in future versions.
0 commit comments