"""
Vector Addition
===============

In this tutorial, you will write a simple vector addition using Triton.

In doing so, you will learn about:

* The basic programming model of Triton.

* The `triton.jit` decorator, which is used to define Triton kernels.

* The best practices for validating and benchmarking your custom ops against native reference implementations.

"""

# %%
# Compute Kernel
# --------------

import torch

import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()
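# NOTE: `get_active_torch_device()` returns the `torch.device` used by the active Triton
# backend (for example, `cuda:0` on an NVIDIA GPU), so the tensors below are allocated there.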


@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid, so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that `offsets` is a block of element indices; adding it to a pointer yields a block of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)


# %%
# Let's also declare a helper function to (1) allocate the `output` tensor
# and (2) enqueue the above kernel with appropriate grid/block sizes:


def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int] or Callable(metaparameters) -> Tuple[int]
    # (a static-tuple variant is sketched after this function).
    # In this case, we use a 1D grid whose size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
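    # For example, with the values used in the test below (n_elements = 98432, BLOCK_SIZE = 1024),
    # the grid is (triton.cdiv(98432, 1024),) == (97,), and the last program masks off the
    # 97 * 1024 - 98432 == 896 out-of-bounds offsets.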
    # NOTE:
    #  - Each torch.Tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - Don't forget to pass meta-parameters as keyword arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to `output` but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return output


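# %%
# A side note on the launch grid: when the meta-parameters are already known, a plain tuple can be
# passed instead of a callable. The following is only an illustrative sketch of that option (the
# callable form above is what this tutorial uses); `add_static_grid` is a hypothetical name, not part
# of the tutorial:


def add_static_grid(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()
    # The block size is fixed here, so the grid can be computed up front as a static tuple.
    add_kernel[(triton.cdiv(n_elements, 1024), )](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

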
# %%
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)
output_torch = x + y
output_triton = add(x, y)
print(output_torch.cpu())
print(output_triton.cpu())
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch.cpu() - output_triton.cpu()))}')
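
# %%
# The docstring above also promises validation and benchmarking. A minimal sketch of one way to do
# this, assuming `torch.testing.assert_close` and `triton.testing.do_bench` are available in your
# PyTorch/Triton versions:

# Assert element-wise closeness instead of only printing the maximum difference.
torch.testing.assert_close(output_triton, output_torch)

# `do_bench` runs the callable many times and reports a timing in milliseconds.
ms_triton = triton.testing.do_bench(lambda: add(x, y))
ms_torch = triton.testing.do_bench(lambda: x + y)
print(f'triton: {ms_triton:.4f} ms, torch: {ms_torch:.4f} ms')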