
Commit c525a02

[dynamo, docs] cherry pick torch.compile programming model docs into 2.8 (pytorch#159373)
* [dynamo, docs] cherry pick torch.compile programming model docs into 2.8
* revert requirements-docs.txt
* add remaining docs, update conf.py with myst_nb
1 parent a1cb3cc commit c525a02

20 files changed: +1877 -1 lines changed

docs/source/compile/header_code.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import functools
import os

import torch


# to lower notebook execution time while hiding backend="eager"
torch.compile = functools.partial(torch.compile, backend="eager")

# to clear torch logs format
os.environ["TORCH_LOGS_FORMAT"] = ""
torch._logging._internal.DEFAULT_FORMATTER = (
    torch._logging._internal._default_formatter()
)
torch._logging._internal._init_logs()
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  execution_show_tb: True
  merge_streams: True
---

```{code-cell}
:tags: [remove-cell]
import torch

import header_code

torch._logging.set_logs(graph_breaks=True)
```

# Common Graph Breaks

Below are some common graph breaks and some workarounds.

## Incorrect Code
Your code might contain errors (meaning it doesn't execute even without `torch.compile`). In the example below, the `torch.sin` call mistakenly passes an extra argument. **Always disable `torch.compile` to check if the code runs correctly.**


```{code-cell}
@torch.compile
def fn(x):
    y = torch.sin(x, x)
    return y

try:
    fn(torch.ones(3, 3))
except Exception as e:
    pass
```

Dynamo makes a best-effort attempt to hint when a graph break is caused by an error in your code.
Even so, it can sometimes be difficult to tell from the logs whether a graph break is caused by an error in your code,
is a more complicated graph break, or is a `torch.compile` bug. To differentiate, we recommend running your code without `torch.compile` to see whether you still get the error reported by the graph break.
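
For instance, here is a sketch of that check applied to the example above (the `fn_eager` name is only for illustration): calling the same function without `torch.compile` surfaces the underlying error directly.

```{code-cell}
# Sketch: the same buggy function, run without torch.compile, to confirm
# the error comes from the user code rather than from the compiler.
def fn_eager(x):
    y = torch.sin(x, x)  # extra positional argument
    return y

try:
    fn_eager(torch.ones(3, 3))
except Exception as e:
    print(type(e).__name__, ":", e)
```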

## Data-dependent operations

`torch.compile` graph breaks on data-dependent operations such as data-dependent control flow (if-statements, loops with tensors) and direct tensor data accesses (`.item`, `.data_ptr`).

```{code-cell}
@torch.compile
def fn(x):
    y = x.sum()
    if y > 0:
        return x + y.item()
    return x - y.item()

print(fn(torch.ones(3, 3)))
```

The general workaround for these graph breaks is to avoid data-dependent operations. Some specific workarounds are:

- If your control flow doesn't actually depend on data values, consider modifying your code to perform control flow on constants.


```{code-cell}
# old
x = torch.randn(3, 3)
@torch.compile
def fn(y):
    if x.sum() > 0:
        return y + x
    else:
        return y - x

print(fn(torch.ones(3, 3)))
```

```{code-cell}
# new
x = torch.randn(3, 3)
cond = (x.sum() > 0).item()
@torch.compile
def fn(y):
    if cond:
        return y + x
    else:
        return y - x

print(fn(torch.ones(3, 3)))
```

- Use higher-order ops like {ref}`cond` in place of data-dependent control flow


```{code-cell}
# old
@torch.compile
def fn(x):
    if x.sum() > 0:
        return x + 1
    return x - 1

print(fn(torch.ones(3, 3)))
```

```{code-cell}
# new
@torch.compile
def fn(x):
    return torch.cond(
        x.sum() > 0,
        lambda x: x + 1,
        lambda x: x - 1,
        (x,),
    )

print(fn(torch.ones(3, 3)))
```

- If you have a `.item()` call, try `torch._dynamo.config.capture_scalar_outputs = True`
  or `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` (see the sketch after this list).
- Wrap problematic parts of the function in a custom operator
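
As a minimal sketch of the `.item()` workaround above (setting the config in code; the `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` environment variable has the same effect):

```{code-cell}
# Sketch: let Dynamo capture .item() as a symbolic scalar instead of graph breaking.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile
def fn(x):
    y = x.sum().item()  # captured in the graph rather than causing a graph break
    return x + y

print(fn(torch.ones(3, 3)))
```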

## Printing and logging

Printing/logging/issuing warnings will result in a graph break.
You can try working around this by using `torch._dynamo.config.reorderable_logging_functions`.
This config is used to reorder logging functions so that they are called at the end of the
traced function, thus avoiding a graph break.
However, the logged contents may differ if, for example, a mutation occurs.


```{code-cell}
torch._dynamo.config.reorderable_logging_functions.add(print)

@torch.compile
def fn(x):
    x += 1
    print("log!")
    return torch.sin(x)

print(fn(torch.ones(3, 3)))
```
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  execution_show_tb: True
  merge_streams: True
---

```{code-cell}
:tags: [remove-cell]
import torch

import header_code

torch._logging.set_logs(graph_breaks=True, graph_code=True)
```

# Disabling and Suppressing Errors
For some model architectures, there are portions of the model that are particularly difficult to compile -
either they cause many graph breaks, or they crash.
You may want to explicitly disable these problematic portions of the model so that you can apply
`torch.compile` to the parts that work. You can do this by using the `@torch.compiler.disable` decorator.
When `torch.compile` attempts to call a disabled function, it breaks the graph, skips tracing the disabled function,
and resumes tracing after the call. By default, all recursive calls made from a disabled function are also disabled.
Use the `recursive=False` option to allow compilation of recursive calls.

```{code-cell}
def inner1(x):
    torch._dynamo.graph_break()  # not traced
    return x + 1  # not traced

@torch.compiler.disable
def outer1(x):
    x = x + 2  # not traced
    torch._dynamo.graph_break()  # not traced
    return inner1(x)

@torch.compile
def f(x):
    x = outer1(x)
    return x + 4  # traced

print(f(torch.ones(3)))
```

```{code-cell}
def inner2(x):
    torch._dynamo.graph_break()  # traced
    return x + 1  # traced

@torch.compiler.disable(recursive=False)
def outer2(x):
    x = x + 2  # not traced
    torch._dynamo.graph_break()  # not traced
    return inner2(x)

@torch.compile
def g(x):
    x = outer2(x)
    return x + 4  # traced

print(g(torch.ones(3)))
```

For example, one can use `torch.compiler.disable` to disable `torch.compile` on the sparse architecture in
recommendation models, since the sparse arch is difficult to compile.
Preprocessing and logging functions are other examples of functions that typically cause
many graph breaks and do not benefit from being compiled. A sketch of this pattern follows.
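
Below is a minimal sketch of this pattern on a hypothetical model; the `HypotheticalRecModel` class and its `sparse_arch` method are illustrative stand-ins, not APIs from the docs:

```{code-cell}
class HypotheticalRecModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(8, 8)

    @torch.compiler.disable
    def sparse_arch(self, x):
        # stand-in for a hard-to-compile sparse lookup
        torch._dynamo.graph_break()
        return x * 2

    def forward(self, x):
        x = self.sparse_arch(x)  # skipped by the compiler
        return self.dense(x)  # still compiled

model = torch.compile(HypotheticalRecModel())
print(model(torch.ones(2, 8)))
```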

If you are experiencing compiler crashes and you want to continue regardless,
you can set `torch._dynamo.config.suppress_errors = True`.
When the compiler crashes, we will just skip tracing the function and try again later.
**This is not best practice** - it is better to eventually manually add `disable` annotations as necessary.
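
As a sketch, the flag is an ordinary module-level config setting. Nothing in this toy function actually crashes the compiler, so it compiles normally; the flag only changes behavior when an internal compiler error occurs:

```{code-cell}
# Sketch: skip (rather than fail on) frames whose compilation raises an internal error.
torch._dynamo.config.suppress_errors = True

@torch.compile
def fn(x):
    return torch.sin(x) + 1

print(fn(torch.ones(3)))
```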
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
# Custom Operators

**Summary:**
- Use custom operators to have `torch.compile` treat a function as opaque. `torch.compile` will never trace into the function and Inductor (the backend) will run the function as-is.

You may wish to use a custom operator in any of the following situations:
- Your code calls some C/C++/CUDA code. Dynamo is a Python bytecode interpreter and generally does not know how to handle calls to C/C++/CUDA functions that are bound to Python.
- Dynamo and non-strict tracing have trouble tracing through a function and you want it to be ignored by `torch.compile`.

Please see [the Python custom ops tutorial](https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial) for more details on how to wrap a Python function into a `torch.compile`-understood custom operator.
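
For reference, here is a minimal sketch in the style of that tutorial; the `mylib::numpy_sin` name and the NumPy-based body are illustrative only:

```python
import numpy as np
import torch

@torch.library.custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    # Opaque to torch.compile: this body runs as-is and is never traced into.
    return torch.from_numpy(np.sin(x.numpy()))

@numpy_sin.register_fake
def _(x):
    # Shape/dtype propagation rule used while tracing.
    return torch.empty_like(x)

@torch.compile
def fn(x):
    return numpy_sin(x)

print(fn(torch.ones(3)))
```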

For more advanced use cases, you may wish to use our C++ Custom Operator API; please see [here](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html) for more information.
