|
5 | 5 | import logging
|
6 | 6 | import time
|
7 | 7 | import warnings
|
| 8 | +from collections.abc import Sequence |
8 | 9 | from itertools import chain
|
9 | 10 | from typing import TYPE_CHECKING
|
10 | 11 |
|
@@ -168,6 +169,59 @@ def validate(self, fgraph):
|
168 | 169 | raise InconsistencyError(f"Trying to destroy a protected variable: {r}")
|
169 | 170 |
|
170 | 171 |
|
| 172 | +def add_supervisor_to_fgraph( |
| 173 | + fgraph: FunctionGraph, |
| 174 | + input_specs: Sequence[SymbolicInput], |
| 175 | + accept_inplace: bool = False, |
| 176 | +) -> None: |
| 177 | + """Setup Supervisor Feature in a FunctionGraph, so that inplace rewrites can be used. |
| 178 | +
|
| 179 | + Parameters |
| 180 | + ---------- |
| 181 | + fgraph: FunctionGraph |
| 182 | + The FunctionGraph to setup the Supervisor Feature in. |
| 183 | + input_specs: Sequence of SymbolicInput |
| 184 | + The input specifications for the FunctionGraph. |
| 185 | + Inputs with the attribute `mutable=False` and which are not already destroyed by an inplace operation |
| 186 | + (if `accept_inplace` is True) will be protected from inplace operations. |
| 187 | + Otherwise, they will be allowed to be destroyed. |
| 188 | + accept_inplace: bool |
| 189 | + Whether to allow inplace operations to already be present in the graph. |
| 190 | +
|
| 191 | + Raises |
| 192 | + ------ |
| 193 | + TypeError |
| 194 | + If inplace operations are not allowed and the graph already contains inplace operations. |
| 195 | +
|
| 196 | + """ |
| 197 | + |
| 198 | + has_destroy_handler = hasattr(fgraph, "destroyers") |
| 199 | + if not (has_destroy_handler and accept_inplace): |
| 200 | + # Check if fgraph already contains destructive operations, |
| 201 | + # in which case we need to add a DestroyHandler or raise an error |
| 202 | + for node in fgraph.apply_nodes: |
| 203 | + if node.op.destroy_map: |
| 204 | + if not accept_inplace: |
| 205 | + raise TypeError( |
| 206 | + f"Graph must not contain inplace operations: {node}" |
| 207 | + ) |
| 208 | + else: |
| 209 | + has_destroy_handler = True |
| 210 | + fgraph.attach_feature(DestroyHandler()) |
| 211 | + break |
| 212 | + |
| 213 | + # Protect all immutable inputs from inplace operations. |
| 214 | + fgraph.attach_feature( |
| 215 | + Supervisor( |
| 216 | + input |
| 217 | + for spec, input in zip(input_specs, fgraph.inputs, strict=True) |
| 218 | + if not ( |
| 219 | + spec.mutable or has_destroy_handler and fgraph.has_destroyers([input]) |
| 220 | + ) |
| 221 | + ) |
| 222 | + ) |
| 223 | + |
| 224 | + |
171 | 225 | def std_fgraph(
|
172 | 226 | input_specs: list[SymbolicInput],
|
173 | 227 | output_specs: list[SymbolicOutput],
|
@@ -229,24 +283,8 @@ def std_fgraph(
|
229 | 283 |
|
230 | 284 | found_updates.extend(map(SymbolicOutput, updates))
|
231 | 285 |
|
232 |
| - for node in fgraph.apply_nodes: |
233 |
| - if node.op.destroy_map: |
234 |
| - if not accept_inplace: |
235 |
| - raise TypeError(f"Graph must not contain inplace operations: {node}") |
236 |
| - else: |
237 |
| - fgraph.attach_feature(DestroyHandler()) |
238 |
| - break |
239 |
| - |
240 |
| - # We need to protect all immutable inputs from inplace operations. |
241 |
| - fgraph.attach_feature( |
242 |
| - Supervisor( |
243 |
| - input |
244 |
| - for spec, input in zip(input_specs, fgraph.inputs, strict=True) |
245 |
| - if not ( |
246 |
| - spec.mutable |
247 |
| - or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input])) |
248 |
| - ) |
249 |
| - ) |
| 286 | + add_supervisor_to_fgraph( |
| 287 | + fgraph=fgraph, input_specs=input_specs, accept_inplace=accept_inplace |
250 | 288 | )
|
251 | 289 |
|
252 | 290 | # If named nodes are replaced, keep the name
|
|
0 commit comments