@@ -206,6 +206,45 @@ input(s)'s memory). From there, go to the previous section.
206
206
Consider using :class: `DebugMode ` when developing
207
207
a new :class: `Op ` that uses :attr: `Op.view_map ` and/or :attr: `Op.destroy_map `.
208
208
209
+ The `inplace_on_inputs ` Method
210
+ ==============================
211
+
212
+ PyTensor provides a method :meth: `Op.inplace_on_inputs ` that allows an `Op ` to
213
+ create a version of itself that operates inplace on as many of the requested
214
+ inputs as possible while avoiding inplace operations on non-requested inputs.
215
+
216
+ This method takes a list of input indices where inplace operations are allowed
217
+ and returns a new `Op ` instance that will perform inplace operations only on
218
+ those inputs where it is safe and beneficial to do so.
219
+
220
+ .. testcode ::
221
+
222
+ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
223
+ """Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs `."""
224
+ # Implementation would create a new Op with appropriate destroy_map
225
+ # Return self by default if no inplace version is available
226
+ return self
227
+
228
+ Currently, this method is primarily used with Blockwise operations through PyTensor's
229
+ rewriting system, but it will be extended to support core ops directly in future versions.
230
+ The rewriting system automatically calls this method to optimize memory usage by
231
+ enabling inplace operations where they do not interfere with the computation graph's
232
+ correctness.
233
+
234
+ When implementing this method in a custom `Op `:
235
+
236
+ - Return a new instance of your `Op ` with a :attr: `destroy_map ` that reflects
237
+ the inplace operations on the allowed inputs
238
+ - Ensure that inplace operations are only performed on inputs that are in the
239
+ `allowed_inplace_inputs ` list
240
+ - Return `self ` if no inplace optimization is possible or beneficial
241
+ - The returned `Op ` should be functionally equivalent to the original but with
242
+ better memory efficiency
243
+
244
+ .. note ::
245
+ This method is automatically used by PyTensor's optimization pipeline and typically
246
+ does not need to be called directly by user code.
247
+
209
248
Inplace Rewriting and `DebugMode `
210
249
=================================
211
250
0 commit comments