# ``optimizer.zero_grad(set_to_none=True)``.

###############################################################################
# Fuse operations
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# Pointwise operations such as elementwise addition, multiplication, and math
# functions like ``sin()``, ``cos()``, ``sigmoid()``, etc., can be combined
# into a single kernel. This fusion helps amortize memory access time and
# kernel launch time. Pointwise operations are typically memory-bound: for
# each operation, PyTorch eager mode launches a separate kernel, which loads
# data from memory, performs the computation (usually the inexpensive step),
# and writes the results back to memory.
#
# A fused operator launches only one kernel for multiple pointwise operations,
# and loads and stores the data only once. This is particularly beneficial for
# activation functions, optimizers, and custom RNN cells.
#
# PyTorch 2 introduces a compile mode backed by TorchInductor, an underlying
# compiler that fuses kernels automatically. TorchInductor extends beyond
# simple elementwise operations, fusing eligible pointwise and reduction
# operations for optimized performance.
#
# In the simplest case, fusion can be enabled by applying the
# `torch.compile <https://pytorch.org/docs/stable/generated/torch.compile.html>`_
# decorator to the function definition, for example:

@torch.compile
def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

###############################################################################
# Refer to
# `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
# for more advanced use cases.

###############################################################################