Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 56dff30

Browse files
committed
added compilation readme
1 parent db34173 commit 56dff30

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

COMPILE_README.md

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# An introduction to some of the (prototype) compilation features in Functorch
2+
3+
The primary compilation API we provide is something called AOTAutograd.
4+
5+
This is currently a prototype feature.
6+
7+
For example, here are some examples of how to use it.
8+
```
9+
from functorch.compile import aot_function, aot_module, draw_graph
10+
import torch.fx as fx
11+
import torch
12+
13+
# This simply prints out the FX graph of the forwards and the backwards
14+
def print_graph(name):
15+
def f(fx_g: fx.GraphModule, inps):
16+
print(name)
17+
print(fx_g.code)
18+
return fx_g
19+
return f
20+
21+
def f(x):
22+
return x.cos().cos()
23+
24+
nf = aot_function(f, fw_compiler=print_graph("forward"), bw_compiler=print_graph("backward"))
25+
nf(torch.randn(3, requires_grad=True))
26+
27+
# You can do whatever you want before and after, and you can still backprop through the function.
28+
inp = torch.randn(3, requires_grad=True)
29+
inp = inp.cos()
30+
out = nf(inp)
31+
out = out.sin().sum().backward()
32+
33+
def f(x):
34+
return x.cos().cos()
35+
36+
# This draws out the forwards and the backwards graphs as svg files
37+
def graph_drawer(name):
38+
def f(fx_g: fx.GraphModule, inps):
39+
draw_graph(fx_g, name)
40+
return fx_g
41+
return f
42+
43+
aot_function(f, fw_compiler=graph_drawer("forward"), bw_compiler=graph_drawer("backward"))(torch.randn(3, requires_grad=True))
44+
45+
# We also have a convenience API for applying AOTAutograd to modules
46+
from torchvision.models import resnet18
47+
aot_module(resnet18(), print_graph("forward"), print_graph("backward"))(torch.randn(1,3,200,200))
48+
# output elided since it's very long
49+
50+
# In practice, you might want to speed it up by sending it to Torchscript. You might also lower it to Torchscript before passing it to another compiler
51+
52+
def f(x):
53+
return x.cos().cos()
54+
55+
def ts_compiler(fx_g: fx.GraphModule, inps):
56+
f = torch.jit.script(fx_g)
57+
print(f.graph)
58+
f = torch.jit.freeze(f.eval()) # Note: This eval() works fine *even* though we're using this for training
59+
return f
60+
61+
aot_function(f, ts_compiler, ts_compiler)(torch.randn(3, requires_grad=True))
62+
```

0 commit comments

Comments
 (0)