[![Build and test](https://github.com/intel/intel-xpu-backend-for-triton/actions/workflows/build-test.yml/badge.svg)](https://github.com/intel/intel-xpu-backend-for-triton/actions/workflows/build-test.yml)
[![Nightly wheels](https://github.com/intel/intel-xpu-backend-for-triton/actions/workflows/nightly-wheels.yml/badge.svg)](https://github.com/intel/intel-xpu-backend-for-triton/actions/workflows/nightly-wheels.yml)

# Intel® XPU Backend for Triton\*

This is the development repository of Intel® XPU Backend for Triton\*, a new [Triton](https://github.com/triton-lang/triton) backend for Intel GPUs.
Intel® XPU Backend for Triton\* is an out-of-tree backend module for [Triton](https://github.com/triton-lang/triton) that provides best-in-class performance and productivity on Intel GPUs, for both [PyTorch](https://github.com/pytorch/pytorch) and standalone usage.

# Compatibility

* Operating systems:
  * [Ubuntu 22.04](http://releases.ubuntu.com/22.04)
* GPU Cards:
  * [Intel® Data Center GPU Max Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/max-series.html)
  * [Intel® Data Center Flex Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/flex-series.html)
  * [Intel® Arc A770](https://www.intel.com/content/www/us/en/products/sku/229151/intel-arc-a770-graphics-16gb/specifications.html)
* GPU Drivers:
  * Latest [Long Term Support (LTS) Release](https://dgpu-docs.intel.com/driver/installation.html)
  * Latest [Rolling Release](https://dgpu-docs.intel.com/driver/installation-rolling.html)
* Toolchain:
  * Latest [Intel® Deep Learning Essentials](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html?packages=dl-essentials&dl-lin=offline&dl-essentials-os=linux)

Note that Intel® XPU Backend for Triton\* is not compatible with Intel® Extension for PyTorch\* and Intel® oneAPI Base Toolkit\*.

# Quick Installation

## Prerequisites

1. Latest [Rolling Release](https://dgpu-docs.intel.com/driver/installation-rolling.html) or [Long Term Support Release](https://dgpu-docs.intel.com/driver/installation.html) of GPU driver
2. Latest release of [Intel® Deep Learning Essentials](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html?packages=dl-essentials&dl-lin=offline&dl-essentials-os=linux)

## Install PyTorch and Triton from nightly wheels

Currently, Intel® XPU Backend for Triton\* requires a special version of PyTorch; both can be installed from nightly wheels.
Navigate to the [nightly wheels workflow](https://github.com/intel/intel-xpu-backend-for-triton/actions/workflows/nightly-wheels.yml),
select the most recent successful run at the top of the page, and download the artifact for the corresponding Python version.
Extract the archive and, in the extracted directory, execute:

```shell
pip install torch-*.whl triton-*.whl
```

Before using Intel® XPU Backend for Triton\* you need to initialize the toolchain.
The default location is `/opt/intel/oneapi` (if installed as a `root` user) or `~/intel/oneapi` (if installed as a regular user).

```shell
# replace /opt/intel/oneapi with the actual location of Intel® Deep Learning Essentials
source /opt/intel/oneapi/setvars.sh
```
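
After installing the wheels and initializing the toolchain, you can run a quick sanity check. This is a minimal sketch, assuming the PyTorch build shipped with the nightly wheels exposes the `torch.xpu` device API:

```Python
# Sanity check for the nightly-wheel installation
# (assumes the torch.xpu device API is available in this PyTorch build).
import torch
import triton

print("torch:", torch.__version__)
print("triton:", triton.__version__)
print("XPU available:", torch.xpu.is_available())
```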

# Install from source

## Prerequisites

1. Latest [Rolling Release](https://dgpu-docs.intel.com/driver/installation-rolling.html) or [Long Term Support Release](https://dgpu-docs.intel.com/driver/installation.html) of GPU driver
2. Latest release of [Intel® Deep Learning Essentials](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html?packages=dl-essentials&dl-lin=offline&dl-essentials-os=linux)

## Compile PyTorch and Triton from source

Currently, Intel® XPU Backend for Triton\* requires a special version of PyTorch; both need to be compiled at the same time.

Before compiling PyTorch and Intel® XPU Backend for Triton\* you need to initialize the toolchain.
The default location is `/opt/intel/oneapi` (if installed as a `root` user) or `~/intel/oneapi` (if installed as a regular user).

```shell
# replace /opt/intel/oneapi with the actual location of Intel® Deep Learning Essentials
source /opt/intel/oneapi/setvars.sh
```

Clone this repository:

```shell
git clone https://github.com/intel/intel-xpu-backend-for-triton.git
cd intel-xpu-backend-for-triton
```

To avoid potential conflicts with installed packages, it is recommended to create and activate a new Python virtual environment:

```shell
python -m venv .venv --prompt triton
source .venv/bin/activate
```

Compile and install PyTorch:

```shell
scripts/install-pytorch.sh --source
```

Compile and install Intel® XPU Backend for Triton\*:

```shell
scripts/compile-triton.sh
```

# Building with a custom LLVM

Triton uses LLVM to generate code for GPUs and CPUs. Normally, the Triton build
downloads a prebuilt LLVM, but you can also build LLVM from source and use that.

LLVM does not have a stable API, so the Triton build will not work at an
arbitrary LLVM version.

1. Find the version of LLVM that Triton builds against.
Check `cmake/llvm-hash.txt` to see the current version.

2. Check out LLVM at this revision to the directory `llvm`,
which must be in the same directory as `intel-xpu-backend-for-triton`:
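
   A minimal sketch of this step, assuming the upstream `llvm-project` repository and that your current directory is the one containing the `intel-xpu-backend-for-triton` checkout:

   ```shell
   # Clone upstream LLVM into ./llvm and check out the pinned revision
   # recorded in cmake/llvm-hash.txt.
   git clone https://github.com/llvm/llvm-project.git llvm
   git -C llvm checkout "$(cat intel-xpu-backend-for-triton/cmake/llvm-hash.txt)"
   ```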

3. In the directory `intel-xpu-backend-for-triton`, build Triton with the custom LLVM:

   ```shell
   ./scripts/compile-triton.sh --llvm --triton
   ```

# Tips for building

- Set `TRITON_BUILD_WITH_CLANG_LLD=true` as an environment variable to use clang
  and lld. lld in particular results in faster builds (a combined usage sketch
  follows this list).

- Set `TRITON_BUILD_WITH_CCACHE=true` to build with ccache.

- Set `TRITON_HOME=/some/path` to change the location of the `.triton`
  directory where Triton's cache is located and downloads are stored
  during the build. By default, this is the user's home directory. It
  can be changed anytime.

- Pass `--no-build-isolation` to `pip install` to make no-op builds faster.
  Without this, every invocation of `pip install` uses a different symlink to
  cmake, and this forces ninja to rebuild most of the `.a` files.

- VS Code IntelliSense has some difficulty figuring out how to build Triton's C++
  (probably because, in our build, users don't invoke cmake directly, but
  instead use setup.py). Teach VS Code how to compile Triton as follows.

  - Do a local build: run `pip install -e python`.
  - Get the full path to the `compile_commands.json` file produced by the build:
    `find python/build -name 'compile_commands.json' | xargs readlink -f`.
    You might get a full path similar to `/Users/{username}/triton/python/build/cmake.macosx-11.1-arm64-cpython-3.12/compile_commands.json`
  - In VS Code, install the
    [C/C++ extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode.cpptools),
    then open the command palette (`Shift + Command + P` on Mac, or
    `Shift + Ctrl + P` on Windows/Linux) and open `C/C++: Edit Configurations (UI)`.
  - Open "Advanced Settings" and paste the full path to
    `compile_commands.json` into the "Compile Commands" textbox.
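
A minimal sketch combining the build-related variables above (it assumes clang, lld, and ccache are installed and that you run it from the repository root with the virtual environment active):

```shell
# Faster local rebuilds: clang/lld, ccache, a custom Triton home, and no
# build isolation (each of these is optional and can be used independently).
export TRITON_BUILD_WITH_CLANG_LLD=true
export TRITON_BUILD_WITH_CCACHE=true
export TRITON_HOME=/some/path
pip install -e python --no-build-isolation
```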

# Running tests

There currently isn't a turnkey way to run all the Triton tests, but you can
use the following recipe:

```shell
scripts/test-triton.sh
```

# Tips for hacking

For detailed instructions on how to debug Triton's frontend, please refer to this [tutorial](https://triton-lang.org/main/programming-guide/chapter-3/debugging.html). The following includes additional tips for hacking on Triton's backend.

**Helpful environment variables**

- `MLIR_ENABLE_DUMP=1` dumps the IR before every MLIR pass Triton runs, for all
  kernels. Use `MLIR_ENABLE_DUMP=kernelName` to dump for a specific kernel only.
  - The Triton cache can interfere with the dump. In cases where `MLIR_ENABLE_DUMP=1` does not work, try cleaning your Triton cache: `rm -r ~/.triton/cache/*` (see the sketch after this list).
- `LLVM_IR_ENABLE_DUMP=1` dumps the IR before every pass run over the LLVM IR.
- `TRITON_INTERPRET=1` uses the Triton interpreter instead of running on the
  GPU. You can insert Python breakpoints in your kernel code!
- `TRITON_ENABLE_LLVM_DEBUG=1` passes `-debug` to LLVM, printing a lot of
  debugging information to stdout. If this is too noisy, run with just
  `TRITON_LLVM_DEBUG_ONLY` instead to limit the output.

  An alternative way to reduce output noisiness is to run with
  `LLVM_IR_ENABLE_DUMP=1`, extract the IR before the LLVM pass of interest, and
  then run LLVM's `opt` standalone, perhaps passing `-debug-only=foo` on the
  command line.
- `TRITON_LLVM_DEBUG_ONLY=<comma-separated>` is the equivalent of LLVM's
  `-debug-only` command-line option. This limits the LLVM debug output to
  specific pass or component names (which are specified using `#define
  DEBUG_TYPE` throughout LLVM and Triton) in order to allow the debug output to
  be less noisy. `TRITON_LLVM_DEBUG_ONLY` allows one or more comma-separated
  values to be specified (e.g.
  `TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions"` or
  `TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions,regalloc"`).
- `USE_IR_LOC={ttir,ttgir}` reparses the IR such that the location information
  is the line number of the IR file with that particular extension,
  instead of the line number of the Python file. This can provide a direct mapping
  from the IR to llir/ptx. When used with performance tools, it can provide a
  breakdown on IR instructions.
- `TRITON_PRINT_AUTOTUNING=1` prints out the best autotuning config and total time
  spent for each kernel after autotuning is complete.
- `DISABLE_LLVM_OPT` disables LLVM optimizations for make_llir and make_ptx
  if its value parses as a Boolean true. Otherwise, it is parsed as a list
  of flags for disabling individual LLVM optimizations, for example
  `DISABLE_LLVM_OPT="disable-lsr"`.
  Loop strength reduction is known to cause up to 10% performance changes for
  certain kernels with register pressure.
- `TRITON_ALWAYS_COMPILE=1` forces kernels to be compiled regardless of cache hits.
- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
- `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.
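
A minimal sketch of a typical debugging invocation that combines a few of these variables (`my_kernel.py` is a hypothetical placeholder script; `add_kernel` matches the kernel name used in Example 1 below):

```shell
# Dump the MLIR passes for a single kernel while bypassing the compilation cache.
# IR dumps are typically written to stderr, hence the 2> redirection.
rm -rf ~/.triton/cache/*
MLIR_ENABLE_DUMP=add_kernel TRITON_ALWAYS_COMPILE=1 python my_kernel.py 2> ir_dump.log
```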
# Usage Guide

## Code Modifications
Intel® XPU Backend for Triton\* requires a special version of PyTorch that can be built from source or installed from nightly wheels.

1. Add `import torch` for XPU support.
2. Move tensors and models to the XPU by calling `to('xpu')`.

This repository contains modified [tutorials](https://github.com/intel/intel-xpu-backend-for-triton/tree/main/python/tutorials) that must be used with Intel® XPU Backend for Triton\*.

The following examples show the modifications required in user code.

### Example 1 : Triton Kernel

This example is a modified version of the [Vector Add](https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html#vector-addition) Triton kernel. Please refer to [Vector Add](https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html#vector-addition) for detailed comments and an illustration of the code semantics.

Compared to the original code, the following modifications are made:

```Python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor):
    # Put the output tensor on the XPU
    output = torch.empty_like(x).xpu()
    assert x.is_xpu and y.is_xpu and output.is_xpu
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    return output

# For manual_seed, use the XPU-specific API
torch.xpu.manual_seed(0)
size = 512
# Tensors need to be created on (or moved to) the XPU
x = torch.rand(size, device='xpu')
y = torch.rand(size, device='xpu')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(
    f'The maximum difference between torch and triton is '
    f'{torch.max(torch.abs(output_torch - output_triton))}'
)
```

### Example 2 : End-to-End Model
Triton is transparent to end-to-end models. You can simply use `torch.compile` with the default `inductor` backend, which automatically generates Triton kernels and benefits from them.

```Python
import torch
from torch._dynamo.testing import rand_strided


class simpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Tensors inside the model should be on the XPU
        self.y = rand_strided((32, 8), (8, 1), device='xpu:0', dtype=torch.float32)

    def forward(self, x):
        z = x + self.y
        return z

# Tensors passed to the model should be on the XPU
x = rand_strided((32, 8), (8, 1), device='xpu:0', dtype=torch.float32)
xpu_model = simpleModel()
# Call torch.compile for optimization
optimized_mod = torch.compile(xpu_model)

graph_result = optimized_mod(x)
```

## Performance Analysis Guide

There are several ways of doing performance analysis.
We recommend using `torch.profiler` for end-to-end performance analysis and using Intel® VTune™ Profiler for more detailed kernel analysis.
Note that you need to explicitly set `TRITON_XPU_PROFILE=1` to enable kernel profiling.

```Bash
export TRITON_XPU_PROFILE=1
```
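
A minimal `torch.profiler` sketch for the vector-add wrapper from Example 1; it assumes a PyTorch build whose profiler exposes `ProfilerActivity.XPU` and that `add` from Example 1 is defined in the same script:

```Python
# Sketch: profile a Triton kernel launch on the XPU with torch.profiler
# (ProfilerActivity.XPU is assumed to be available in this PyTorch build).
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.rand(4096, device='xpu')
y = torch.rand(4096, device='xpu')

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
    add(x, y)  # Triton vector-add wrapper from Example 1

print(prof.key_averages().table(row_limit=10))
```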

# Contributing

Community contributions are more than welcome, whether it be to fix bugs or to add new features on [GitHub](https://github.com/intel/intel-xpu-backend-for-triton). For more detailed instructions, please visit our [contributor's guide](https://github.com/intel/intel-xpu-backend-for-triton/blob/main/CONTRIBUTING.md).

## License

_MIT License_, as found in the [LICENSE](https://github.com/intel/intel-xpu-backend-for-triton/blob/main/LICENSE) file.

## Security

See Intel's [Security Center](https://www.intel.com/content/www/us/en/security-center/default.html)
for information on how to report a potential security issue or vulnerability.

See also: [Security Policy](https://github.com/intel/intel-xpu-backend-for-triton/blob/main/SECURITY.md).