---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  execution_show_tb: True
  merge_streams: True
---

```{code-cell}
:tags: [remove-cell]
import torch

import header_code

torch._logging.set_logs(graph_breaks=True)
```

# Common Graph Breaks

Below are some common graph breaks and some workarounds.

## Incorrect Code
Your code might contain errors, meaning that it fails even without `torch.compile`. In the example below, the `torch.sin` call has a typo: it is passed an extra argument. **Always try running your code without `torch.compile` first to check that it is correct.**

```{code-cell}
@torch.compile
def fn(x):
    y = torch.sin(x, x)  # bug: torch.sin accepts only one tensor argument
    return y

try:
    fn(torch.ones(3, 3))
except Exception:
    # Suppress the error so the rest of this page still executes.
    pass
```

Dynamo makes a best-effort attempt to hint when a graph break is caused by your code.
But it can still be difficult to tell from the logs whether a graph break is caused by an error in your code,
is a more complicated graph break, or is a `torch.compile` bug. To differentiate, we recommend running your code without `torch.compile` and checking whether you still get the error reported by the graph break, as in the sketch below.
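
For instance, calling an undecorated copy of the function above (the name `fn_eager` is ours, for illustration) reproduces the same failure, confirming that the bug is in the code rather than in `torch.compile`:

```{code-cell}
def fn_eager(x):
    y = torch.sin(x, x)  # same extra-argument bug as above
    return y

try:
    fn_eager(torch.ones(3, 3))
except TypeError as e:
    # The same error reproduces eagerly, so torch.compile is not at fault.
    print(e)
```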

## Data-dependent operations

`torch.compile` graph breaks on data-dependent operations such as data-dependent control flow (if-statements or loops whose conditions depend on tensor values) and direct accesses to tensor data (`.item()`, `.data_ptr()`).

```{code-cell}
@torch.compile
def fn(x):
    y = x.sum()
    if y > 0:
        return x + y.item()
    return x - y.item()

print(fn(torch.ones(3, 3)))
```

The general workaround for these graph breaks is to avoid data-dependent operations. Some specific workarounds are:

- If your control flow doesn't actually depend on data values, consider modifying your code to perform control flow on constants.

```{code-cell}
# old
x = torch.randn(3, 3)
@torch.compile
def fn(y):
    if x.sum() > 0:
        return y + x
    else:
        return y - x

print(fn(torch.ones(3, 3)))
```

```{code-cell}
# new
x = torch.randn(3, 3)
cond = (x.sum() > 0).item()
@torch.compile
def fn(y):
    if cond:
        return y + x
    else:
        return y - x

print(fn(torch.ones(3, 3)))
```

- Use higher-order ops like {ref}`cond` in place of data-dependent control flow.

```{code-cell}
# old
@torch.compile
def fn(x):
    if x.sum() > 0:
        return x + 1
    return x - 1

print(fn(torch.ones(3, 3)))
```

```{code-cell}
# new
@torch.compile
def fn(x):
    return torch.cond(
        x.sum() > 0,
        lambda x: x + 1,
        lambda x: x - 1,
        (x,),
    )

print(fn(torch.ones(3, 3)))
```

- If you have a `.item()` call, try setting `torch._dynamo.config.capture_scalar_outputs = True`
  or `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` (see the first sketch after this list).
- Wrap problematic parts of the function in a custom operator (see the second sketch after this list).
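
As a sketch of the first workaround: with `capture_scalar_outputs` enabled, Dynamo traces `.item()` as a symbolic scalar instead of graph breaking. We restore the default afterwards so later cells are unaffected:

```{code-cell}
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile
def fn(x):
    # .item() now yields a symbolic scalar, so tracing continues.
    return x + x.sum().item()

print(fn(torch.ones(3, 3)))
torch._dynamo.config.capture_scalar_outputs = False  # restore the default
```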
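
For the second workaround, here is a minimal sketch using the `torch.library.custom_op` API (available in recent PyTorch releases); the operator name `mylib::untraceable_fn` is made up for illustration. Dynamo treats the custom operator as opaque and does not trace into its body:

```{code-cell}
import numpy as np

@torch.library.custom_op("mylib::untraceable_fn", mutates_args=())
def untraceable_fn(x: torch.Tensor) -> torch.Tensor:
    # Arbitrary Python (here, a NumPy round trip) runs eagerly inside
    # the custom op, hidden from Dynamo's tracer.
    return torch.from_numpy(np.sin(x.numpy()))

@untraceable_fn.register_fake
def _(x):
    # Describes the output's shape/dtype so the compiler can plan around the op.
    return torch.empty_like(x)

@torch.compile
def fn(x):
    return untraceable_fn(x) + 1

print(fn(torch.ones(3, 3)))
```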
## Printing and logging

Printing, logging, and issuing warnings all result in a graph break.
You can try working around this by using `torch._dynamo.config.reorderable_logging_functions`.
This config is used to reorder logging functions so that they are called at the end of the
traced function, thus avoiding a graph break.
However, the logged contents may differ, for example, if a mutation occurs between the original call site and the end of the function.
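
Since the setup for this page enables graph-break logs, compiling a function that calls `print` shows the break directly (a sketch; the exact log output varies across PyTorch versions):

```{code-cell}
@torch.compile
def fn(x):
    x += 1
    print("log!")  # graph break: Dynamo cannot trace through print
    return torch.sin(x)

print(fn(torch.ones(3, 3)))
```

After registering `print` as a reorderable logging function, the graph break disappears and the call is emitted after the traced function runs: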

```{code-cell}
torch._dynamo.config.reorderable_logging_functions.add(print)

@torch.compile
def fn(x):
    x += 1
    print("log!")
    return torch.sin(x)

print(fn(torch.ones(3, 3)))
```