# Writing Mosaic GPU kernels with Pallas

This page is a reference for the most important features of the Pallas:MGPU backend.
It's not a tutorial, and as such we do not expect everyone to read it top to bottom.
Still, it is worth going over to familiarise yourself with the patterns you will find
in other tutorials.

In the examples below, we assume that the following imports are in scope:
```python
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
```

## What is a GPU?

Technically, the NVIDIA GPU architecture looks as follows: the GPU is partitioned into
_streaming multiprocessors_ (SMs). The way this manifests in the CUDA programming model
is that each _CUDA thread block_ (or CTA) is scheduled on exactly one SM, but multiple
blocks can be scheduled onto a single SM at a time.

Each SM contains a chunk of fast memory called _shared memory_ (SMEM) and 4 subdivisions,
each containing a _warp scheduler_ and compute units (ALU, TensorCore, ...).
This is also reflected in CUDA programs: each _warp_ (a group of 32 consecutive CUDA
threads in a block) is assigned to one of those subdivisions in a round-robin fashion.
Similarly to blocks, each warp is assigned to exactly one subdivision (it never migrates),
but multiple warps can be assigned to the same SM subdivision. At each clock cycle, the
warp scheduler of each subdivision tries to select one of its resident warps to execute
the next instruction.

Going further, recent CUDA versions also outline the concept of a _warpgroup_: a group of
4 consecutive warps. Knowing what the hardware looks like, we can see where this is coming
from: 4 consecutive warps occupy the 4 quarters of an SM and let us issue instructions
that utilize the whole SM.

> A GPU can be viewed in many different ways, and here we want to focus on a slightly
  simplified model that is very TensorCore-centric. This should help you navigate the
  complexities of writing kernels involving the TensorCore, but keep in mind that the
  real picture is more complicated.

For our purposes, TensorCore operations have grown so big that it no longer makes much
sense to follow the CUDA model. As such, to us, a GPU is a collection of single-threaded cores
(SMs), with one Pallas:MGPU thread corresponding to a CUDA warpgroup. In this model, each
operation you perform in the kernel occupies the whole CUDA warpgroup, and its constituent
warps always run in lockstep (modulo the jitter from hardware scheduling) and never take
different paths through control flow (with the small exception of `core_map`, which we will
discuss later). One notable addition is that we still allow you to co-schedule multiple
of those Pallas-level threads on the same SM so that they can cooperate and communicate
through shared memory (we realize that by putting them in the same CUDA block).

> This is very similar to the programming model popularized by [Triton](https://triton-lang.org/),
  but as you will see, there are a few differences. Mosaic GPU tends to be lower level,
  which usually means you will have to put in more work, but it also puts you more in control.
  In our view both approaches have their merits and we encourage you to pick the backend that
  suits your needs best! Pallas supports and will continue to support Triton as an alternative
  GPU backend.

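To make the model concrete, here is a minimal sketch of an elementwise kernel: the body is
written from the perspective of a single Pallas thread, i.e. a whole warpgroup executing
each operation in lockstep. We assume here that passing `backend="mosaic_gpu"` to
`pl.pallas_call` selects the Mosaic GPU backend; the exact launch API may differ between
JAX versions.

```python
import jax
import jax.numpy as jnp

def add_one_kernel(x_ref, o_ref):
  # The whole body runs as a single Pallas thread: one CUDA warpgroup
  # (128 CUDA threads) executing every operation in lockstep.
  o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    backend="mosaic_gpu",  # assumption: this flag selects the Mosaic GPU backend
)(x)
```
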
### In-order execution & using multiple hardware units

Unlike more complicated CPU architectures, GPUs only support in-order execution. That, however,
does not mean that only a single instruction is running at any given time! Each SM quarter
has multiple independent functional units: the TensorCore, the arithmetic logic unit (ALU), the
load/store unit (LSU) and the special function unit (SFU). If one instruction targets one of
those units and is followed by another instruction that does not use its result, the warp
scheduler can issue the second instruction before the first one completes. This is often referred
to as instruction-level parallelism (ILP) and is a common theme in modern TensorCore kernels:
TensorCore operations are so big and take so many cycles to complete that it would be a waste
not to try to use the other units in the meantime.

To extend this even further, we can take advantage of this hardware-unit-level parallelism by
allowing multiple Pallas threads (warpgroups) to run concurrently. If one thread primarily
occupies the ALU, while another primarily issues TensorCore-related instructions, we can
take advantage of the efficient context switching built into the warp schedulers to keep both
units busy. This is one of the core ideas behind algorithms such as [FlashAttention 3](https://arxiv.org/abs/2407.08608)
or the [CUTLASS ping-pong matmul kernels](https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/).

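Below is a structural sketch of such a warp-specialized kernel. It assumes that the
`plgpu.kernel` entry point accepts `num_threads`/`thread_name` arguments for co-scheduling
Pallas threads in one CUDA block, and that `lax.axis_index` returns the thread index; the
bodies are intentionally elided, since the point is only the overall shape of the kernel.

```python
import jax
import jax.numpy as jnp
from jax import lax

def ping_pong_kernel(o_ref):
  wg = lax.axis_index("wg")  # which Pallas thread (warpgroup) are we?

  @pl.when(wg == 0)
  def _memory_thread():
    # In a real kernel this thread would mostly issue memory operations
    # (e.g. async copies), keeping the load/store units busy.
    pass

  @pl.when(wg == 1)
  def _compute_thread():
    # In a real kernel this thread would mostly issue TensorCore (MMA)
    # instructions.
    pass

run = plgpu.kernel(
    ping_pong_kernel,
    out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
    num_threads=2,     # two Pallas threads co-scheduled in the same CUDA block
    thread_name="wg",  # the axis name queried via lax.axis_index above
)
# Calling run() would launch the kernel (it takes no inputs in this sketch).
```
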
For more information on how warp scheduling and instruction issue work, we recommend reading
[Analyzing Modern NVIDIA GPU cores](https://arxiv.org/abs/2503.20481).

## Array layouts and reference transforms

TODO

## MMA (TensorCore)

In this section, we focus on how Pallas:MGPU kernels can utilize the TensorCore unit.
NVIDIA continues to change the programming interface of the TensorCore significantly
between different hardware generations, which is why the lowest-level interfaces
differ in Pallas:MGPU as well.

### Hopper (`wgmma`)

TODO

### Blackwell (`tcgen05`)

TODO

## Using `core_map`

TODO

## Synchronization structures and primitives

### `commit_smem`

TODO

### `Barrier`

This is essentially a thin wrapper around an array of PTX `mbarrier` objects and is
passed into the kernel as a reference. All functions involving barriers expect a single
barrier argument, so if the reference contains multiple barriers, you have to select one
of them explicitly using `barriers.at[index]`.

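As an illustration of this indexing pattern, here is a rough sketch of a kernel that
allocates two barriers as scratch and uses one of them to await an asynchronous
GMEM-to-SMEM copy. It assumes `plgpu.Barrier`/`plgpu.SMEM` scratch shapes and the
`plgpu.copy_gmem_to_smem`/`plgpu.barrier_wait` helpers; exact signatures may differ
between JAX versions.

```python
import jax
import jax.numpy as jnp

def double_kernel(x_gmem, o_ref, x_smem, barriers):
  # `barriers` refers to an array of 2 mbarriers; functions taking a barrier
  # expect a single one, so we select it explicitly with `.at[...]`.
  plgpu.copy_gmem_to_smem(x_gmem, x_smem, barriers.at[0])
  plgpu.barrier_wait(barriers.at[0])  # block until the copy has landed in SMEM
  o_ref[...] = x_smem[...] * 2.0

x = jnp.ones((128, 128), jnp.float32)
y = pl.pallas_call(
    double_kernel,
    in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],  # keep the input in GMEM
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[
        plgpu.SMEM((128, 128), jnp.float32),            # SMEM staging buffer
        plgpu.Barrier(num_arrivals=1, num_barriers=2),  # an array of 2 barriers
    ],
    backend="mosaic_gpu",  # assumption: this flag selects the Mosaic GPU backend
)(x)
```
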
`Barrier`s are always allocated in SMEM and as such have relatively low overhead.
There are three primary use cases that require `Barrier`s:

1. Awaiting asynchronous GMEM-to-SMEM copies

TODO

2. Cross-warpgroup synchronization

TODO

3. Awaiting `tcgen05` TensorCore instructions

TODO

### `ClusterBarrier`

TODO

### `Semaphore`

TODO

## Asynchronous copies

TODO

## Inline Mosaic GPU

TODO

## Compiler parameters

TODO