
Commit d768509

docs for functionalize (#874)
ghstack-source-id: c27a886 Pull Request resolved: #876
1 parent fb70a3c commit d768509

File tree

3 files changed: +165 -0 lines changed


docs/source/experimental.rst

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
functorch.experimental
======================

.. currentmodule:: functorch.experimental

Experimental Function Transforms
--------------------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    functionalize

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ Check out our `whirlwind tour <whirlwind_tour>`_ or some of our tutorials mentio
    :caption: API Reference and Notes

    functorch
+   experimental
    aot_autograd

.. toctree::

functorch/_src/eager_transforms.py

Lines changed: 152 additions & 0 deletions
@@ -1239,6 +1239,158 @@ def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool):


def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
    """
    functionalize is a transform that can be used to remove (intermediate)
    mutations and aliasing from a function, while preserving the function's
    semantics.

    ``functionalize(func)`` returns a new function with the same semantics
    as ``func``, but with all intermediate mutations removed.
    Every in-place operation performed on an intermediate tensor:
    ``intermediate.foo_()``
    gets replaced by its out-of-place equivalent:
    ``intermediate_updated = intermediate.foo()``.

    functionalize is useful for shipping a PyTorch program off to
    backends or compilers that aren't able to easily represent
    mutations or aliasing operators.

    Args:
        func (Callable): A Python function that takes one or more arguments.
        remove (str): An optional string argument that takes on either
            the value 'mutations' or 'mutations_and_views'.
            If 'mutations' is passed in, then all mutating operators
            will be replaced with their non-mutating equivalents.
            If 'mutations_and_views' is passed in, then additionally all aliasing
            operators will be replaced with their non-aliasing equivalents.
            Default: 'mutations'.

    Returns:
        Returns a new "functionalized" function. It takes the same inputs as
        :attr:`func`, and has the same behavior, but any mutations
        (and optionally aliasing) performed on intermediate tensors
        in the function will be removed.

    functionalize will also remove mutations (and views) that were performed on function inputs.
    However, to preserve semantics, functionalize will "fix up" the mutations after
    the transform has finished running, by detecting if any tensor inputs "should have"
    been mutated, and copying the new data back to the inputs if necessary.

    Example::

        >>> import torch
        >>> from functorch import make_fx
        >>> from functorch.experimental import functionalize
        >>>
        >>> # A function that uses mutations and views, but only on intermediate tensors.
        >>> def f(a):
        ...     b = a + 1
        ...     c = b.view(-1)
        ...     c.add_(1)
        ...     return b
        ...
        >>> inpt = torch.randn(2)
        >>>
        >>> out1 = f(inpt)
        >>> out2 = functionalize(f)(inpt)
        >>>
        >>> # semantics are the same (outputs are equivalent)
        >>> print(torch.allclose(out1, out2))
        True
        >>>
        >>> f_traced = make_fx(f)(inpt)
        >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
        >>>
        >>> print(f_traced.code)

        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1); a_1 = None
            view = torch.ops.aten.view(add, [-1])
            add_ = torch.ops.aten.add_(view, 1); view = None
            return add

        >>> print(f_no_mutations_traced.code)

        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1); a_1 = None
            view = torch.ops.aten.view(add, [-1]); add = None
            add_1 = torch.ops.aten.add(view, 1); view = None
            view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None
            return view_1

        >>> print(f_no_mutations_and_views_traced.code)

        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1); a_1 = None
            view_copy = torch.ops.aten.view_copy(add, [-1]); add = None
            add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None
            view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None
            return view_copy_1

        >>> # A function that mutates its input tensor
        >>> def f(a):
        ...     b = a.view(-1)
        ...     b.add_(1)
        ...     return a
        ...
        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
        >>>
        >>> # All mutations and views have been removed,
        >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
        >>> # after the function has completed.
        >>> print(f_no_mutations_and_views_traced.code)

        def forward(self, a_1):
            view_copy = torch.ops.aten.view_copy(a_1, [-1])
            add = torch.ops.aten.add(view_copy, 1); view_copy = None
            view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None
            copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None
            return view_copy_1

    There are a few "failure modes" for functionalize that are worth calling out:
      (1) Like other functorch transforms, ``functionalize()`` doesn't work with functions
          that directly use ``.backward()``. The same is true for ``torch.autograd.grad``.
          If you want to use autograd, you can compute gradients directly
          with ``functionalize(grad(f))``.
      (2) Like other functorch transforms, ``functionalize()`` doesn't work with global state.
          If you call ``functionalize(f)`` on a function that takes views / mutations of
          non-local state, functionalization will simply no-op and pass the view/mutation
          calls directly to the backend.
          One way to work around this is to ensure that any non-local state creation
          is wrapped into a larger function, which you then call functionalize on.
      (3) ``resize_()`` has some limitations: functionalize will only work on programs
          that use ``resize_()`` as long as the tensor being resized is not a view.
      (4) ``as_strided()`` has some limitations: functionalize will not work on
          ``as_strided()`` calls that result in tensors with overlapping memory.

    Finally, a helpful mental model for understanding functionalization is that
    most user PyTorch programs are written with the public torch API.
    When executed, torch operators are generally decomposed into
    our internal C++ "ATen" API.
    The logic for functionalization happens entirely at the level of ATen.
    Functionalization knows how to take every aliasing operator in ATen,
    and map it to its non-aliasing equivalent
    (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
    and how to take every mutating operator in ATen,
    and map it to its non-mutating equivalent
    (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
    while tracking aliases and mutations out-of-line to know when to fix things up.
    Information about which ATen operators are aliasing or mutating all comes from
    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
    """
    if remove == 'mutations':
        reapply_views = True
    elif remove == 'mutations_and_views':
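
A quick eager sanity check of the input-mutation "fix up" described in the new docstring (not part of this commit, just an illustrative sketch): functionalize is expected to copy the updated data back to any mutated inputs, so the caller still observes the mutation. The tensor names x1/x2 below are illustrative.

import torch
from functorch.experimental import functionalize

# Same input-mutating function as in the second docstring example.
def f(a):
    b = a.view(-1)
    b.add_(1)
    return a

x1 = torch.zeros(2)
x2 = torch.zeros(2)

out1 = f(x1)                 # mutates x1 in place
out2 = functionalize(f)(x2)  # mutations removed internally, then copied back to x2

print(torch.allclose(x1, x2))      # True: the input mutation is still visible to the caller
print(torch.allclose(out1, out2))  # True: outputs agree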

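Failure mode (1) in the docstring points to functionalize(grad(f)) as the way to combine functionalization with autograd. A minimal sketch of that composition (assumed usage, not taken from this commit), using a scalar-valued variant of the first docstring example:

import torch
from functorch import grad
from functorch.experimental import functionalize

def f(a):
    b = a + 1
    c = b.view(-1)
    c.add_(1)       # in-place update of an intermediate tensor
    return b.sum()  # grad() requires a scalar output

inpt = torch.randn(2)

# Gradients computed with and without functionalization should agree.
g1 = grad(f)(inpt)
g2 = functionalize(grad(f))(inpt)
print(torch.allclose(g1, g2))  # True

The functionalized gradient function can then be traced with make_fx, exactly as in the docstring examples, to obtain a mutation-free gradient graph.
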
0 commit comments
