@@ -1239,6 +1239,158 @@ def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool):
def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
+ """
1243
+ functionalize is a transform that can be used to remove (intermediate)
1244
+ mutations and aliasing from a function, while preserving the function's
1245
+ semantics.
1246
+
1247
+ ``functionalize(func)`` returns a new function with the same semantics
1248
+ as ``func``, but with all intermediate mutations removed.
1249
+ Every inplace operation performed on an intermediate tensor:
1250
+ ``intermediate.foo_()``
1251
+ gets replaced by its out-of-place equivalent:
1252
+ ``intermediate_updated = intermediate.foo()``.
1253
+
1254
+ functionalize is useful for shipping a pytorch program off to
1255
+ backends or compilers that aren't able to easily represent
1256
+ mutations or aliasing operators.
1257
+
1258
+    Args:
+        func (Callable): A Python function that takes one or more arguments.
+        remove (str): An optional string argument that takes on either
+            the value 'mutations' or 'mutations_and_views'.
+            If 'mutations' is passed in, then all mutating operators
+            will be replaced with their non-mutating equivalents.
+            If 'mutations_and_views' is passed in, then additionally, all aliasing
+            operators will be replaced with their non-aliasing equivalents.
+            Default: 'mutations'.
+
+    Returns:
+        Returns a new "functionalized" function. It takes the same inputs as
+        :attr:`func`, and has the same behavior, but any mutations
+        (and optionally aliasing) performed on intermediate tensors
+        in the function will be removed.
+
+    functionalize will also remove mutations (and views) that were performed on function inputs.
+    However, to preserve semantics, functionalize will "fix up" the mutations after
+    the transform has finished running, by detecting if any tensor inputs "should have"
+    been mutated, and copying the new data back to the inputs if necessary.
+
+
+    Example::
+
+        >>> import torch
+        >>> from functorch import make_fx
+        >>> from functorch.experimental import functionalize
+        >>>
+        >>> # A function that uses mutations and views, but only on intermediate tensors.
+        >>> def f(a):
+        ...     b = a + 1
+        ...     c = b.view(-1)
+        ...     c.add_(1)
+        ...     return b
+        ...
+        >>> inpt = torch.randn(2)
+        >>>
+        >>> out1 = f(inpt)
+        >>> out2 = functionalize(f)(inpt)
+        >>>
+        >>> # semantics are the same (outputs are equivalent)
+        >>> print(torch.allclose(out1, out2))
+        True
+        >>>
+        >>> f_traced = make_fx(f)(inpt)
+        >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
+        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
+        >>>
+        >>> print(f_traced.code)
+
+
+
+        def forward(self, a_1):
+            add = torch.ops.aten.add(a_1, 1); a_1 = None
+            view = torch.ops.aten.view(add, [-1])
+            add_ = torch.ops.aten.add_(view, 1); view = None
+            return add
+
+        >>> print(f_no_mutations_traced.code)
+
+
+
+        def forward(self, a_1):
+            add = torch.ops.aten.add(a_1, 1); a_1 = None
+            view = torch.ops.aten.view(add, [-1]); add = None
+            add_1 = torch.ops.aten.add(view, 1); view = None
+            view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None
+            return view_1
+
+        >>> print(f_no_mutations_and_views_traced.code)
+
+
+
+        def forward(self, a_1):
+            add = torch.ops.aten.add(a_1, 1); a_1 = None
+            view_copy = torch.ops.aten.view_copy(add, [-1]); add = None
+            add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None
+            view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None
+            return view_copy_1
+
+
+        >>> # A function that mutates its input tensor
+        >>> def f(a):
+        ...     b = a.view(-1)
+        ...     b.add_(1)
+        ...     return a
+        ...
+        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
+        >>>
+        >>> # All mutations and views have been removed,
+        >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
+        >>> # after the function has completed.
+        >>> print(f_no_mutations_and_views_traced.code)
+
+
+
+        def forward(self, a_1):
+            view_copy = torch.ops.aten.view_copy(a_1, [-1])
+            add = torch.ops.aten.add(view_copy, 1); view_copy = None
+            view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None
+            copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None
+            return view_copy_1
+
+
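+        >>> # A minimal sketch (illustrative, not one of the traced examples) of the
+        >>> # input "fix up" described above: at runtime, the functionalized function
+        >>> # is expected to leave its input mutated just as the original f does,
+        >>> # because the new data is copied back to the input once the transform finishes.
+        >>> inpt_a, inpt_b = torch.zeros(2), torch.zeros(2)
+        >>> _ = f(inpt_a)
+        >>> _ = functionalize(f)(inpt_b)
+        >>> print(torch.allclose(inpt_a, inpt_b))
+        True
+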
+    There are a few "failure modes" for functionalize that are worth calling out:
+      (1) Like other functorch transforms, `functionalize()` doesn't work with functions
+          that directly use `.backward()`. The same is true for `torch.autograd.grad`.
+          If you want to use autograd, you can compute gradients directly
+          with `functionalize(grad(f))`.
+      (2) Like other functorch transforms, `functionalize()` doesn't work with global state.
+          If you call `functionalize(f)` on a function that takes views / mutations of
+          non-local state, functionalization will simply no-op and pass the view/mutation
+          calls directly to the backend.
+          One way to work around this is to ensure that any non-local state creation
+          is wrapped into a larger function, which you then call functionalize on
+          (see the sketch after this list).
+      (3) `resize_()` has some limitations: functionalize will only work on programs
+          that use `resize_()` as long as the tensor being resized is not a view.
+      (4) `as_strided()` has some limitations: functionalize will not work on
+          `as_strided()` calls that result in tensors with overlapping memory.
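+
+    As a minimal sketch of the workaround in (2) (``create_and_mutate`` below is an
+    illustrative name, not part of the API): rather than functionalizing a function
+    that mutates non-local state, create the state inside a larger wrapper function
+    and call functionalize on that wrapper::
+
+        >>> def create_and_mutate(x):
+        ...     state = torch.zeros_like(x)  # state is created inside the wrapped function,
+        ...     state.add_(x)                # so functionalize can remove this mutation
+        ...     return state
+        ...
+        >>> out = functionalize(create_and_mutate)(torch.ones(2))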
+
+
+    Finally, a helpful mental model for understanding functionalization is that
+    most user pytorch programs are written with the public torch API.
+    When executed, torch operators are generally decomposed into
+    our internal C++ "ATen" API.
+    The logic for functionalization happens entirely at the level of ATen.
+    Functionalization knows how to take every aliasing operator in ATen,
+    and map it to its non-aliasing equivalent
+    (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
+    and how to take every mutating operator in ATen,
+    and map it to its non-mutating equivalent
+    (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
+    while tracking aliases and mutations out-of-line to know when to fix things up.
+    Information about which ATen operators are aliasing or mutating all comes from
+    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
+    """
    if remove == 'mutations':
        reapply_views = True
    elif remove == 'mutations_and_views':