# An introduction to some of the (prototype) compilation features in functorch

The primary compilation API we provide is called AOTAutograd. It traces both the forwards and the backwards pass of a function ahead of time into FX graphs, which you can then inspect, transform, or hand off to a compiler of your choice.

This is currently a prototype feature.

Here are some examples of how to use it.
```
from functorch.compile import aot_function, aot_module, draw_graph
import torch.fx as fx
import torch

# This compiler simply prints out the FX graphs of the forwards and the backwards pass
def print_graph(name):
    def f(fx_g: fx.GraphModule, inps):
        print(name)
        print(fx_g.code)
        return fx_g
    return f

def f(x):
    return x.cos().cos()

nf = aot_function(f, fw_compiler=print_graph("forward"), bw_compiler=print_graph("backward"))
nf(torch.randn(3, requires_grad=True))

# You can do whatever you want before and after the compiled function, and you
# can still backprop through it.
inp = torch.randn(3, requires_grad=True)
inp = inp.cos()
out = nf(inp)
out.sin().sum().backward()

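# Sanity check (a minimal sketch): aot_function preserves the semantics of f,
# so outputs and gradients should match eager mode.
check_inp = torch.randn(3, requires_grad=True)
assert torch.allclose(f(check_inp), nf(check_inp))
(eager_grad,) = torch.autograd.grad(f(check_inp).sum(), check_inp)
(aot_grad,) = torch.autograd.grad(nf(check_inp).sum(), check_inp)
assert torch.allclose(eager_grad, aot_grad)
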
def f(x):
    return x.cos().cos()

# This compiler draws the forwards and the backwards graphs as SVG files
def graph_drawer(name):
    def f(fx_g: fx.GraphModule, inps):
        draw_graph(fx_g, name)
        return fx_g
    return f

aot_function(f, fw_compiler=graph_drawer("forward"), bw_compiler=graph_drawer("backward"))(torch.randn(3, requires_grad=True))

# We also have a convenience API for applying AOTAutograd to modules
from torchvision.models import resnet18
aot_module(resnet18(), print_graph("forward"), print_graph("backward"))(torch.randn(1, 3, 200, 200))
# output elided since it's very long

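# aot_module returns a regular nn.Module, so (as a minimal sketch of a
# training step) you can optimize it like any other module:
mod = aot_module(resnet18(), print_graph("forward"), print_graph("backward"))
opt = torch.optim.SGD(mod.parameters(), lr=0.01)
opt.zero_grad()
mod(torch.randn(1, 3, 200, 200)).sum().backward()
opt.step()
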
# In practice, you might want to speed the graphs up by lowering them to
# TorchScript. You can also lower them to TorchScript first and then pass them
# on to another compiler.

def f(x):
    return x.cos().cos()

def ts_compiler(fx_g: fx.GraphModule, inps):
    f = torch.jit.script(fx_g)
    print(f.graph)
    # Note: this eval() works fine *even though* we're using the function for
    # training. The extracted forwards and backwards graphs are purely
    # functional, so freezing them does not interfere with autograd.
    f = torch.jit.freeze(f.eval())
    return f

aot_function(f, ts_compiler, ts_compiler)(torch.randn(3, requires_grad=True))
```
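
AOTAutograd also accepts a `partition_fn` that controls how the traced joint graph is split into a forwards and a backwards graph. As a minimal sketch (assuming your build of functorch exposes `ts_compile` and `min_cut_rematerialization_partition` from `functorch.compile`, as recent prototypes do), you can combine the min-cut rematerialization partitioner, which recomputes cheap intermediates in the backwards pass instead of saving them, with the bundled TorchScript compiler:
```
from functorch.compile import aot_function, ts_compile, min_cut_rematerialization_partition
import torch

def g(x):
    return x.cos().cos()

# The partitioner decides which intermediate values are saved for the
# backwards pass and which are recomputed.
ng = aot_function(
    g,
    fw_compiler=ts_compile,
    bw_compiler=ts_compile,
    partition_fn=min_cut_rematerialization_partition,
)
ng(torch.randn(3, requires_grad=True)).sum().backward()
```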