
Commit aa6e701

Merge pull request jax-ml#27827 from apaszke:mgpu-docs
PiperOrigin-RevId: 745119028
2 parents 73ecf0b + 511f782 commit aa6e701

File tree

4 files changed: 270 additions & 1 deletion
(Diff for one file not shown: 99 additions & 0 deletions)

docs/pallas/gpu/index.rst

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
Pallas:Mosaic GPU
=================
Backend-specific documentation for the Mosaic GPU backend.

.. toctree::
   :caption: Reference documentation
   :maxdepth: 2

   reference

.. toctree::
   :caption: Guides
   :maxdepth: 2

docs/pallas/gpu/reference.md

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
# Writing Mosaic GPU kernels with Pallas

This page is a reference for the most important features of the Pallas:MGPU backend.
It's not a tutorial and as such we do not expect everyone to read it top to bottom.
Still, it is worth going over
just to familiarise yourself with some patterns you can find in other tutorials.

In the following examples, we're going to assume the following imports are in scope:
```python
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
```

## What is a GPU?

Technically, the NVIDIA GPU architecture looks as follows: the GPU is partitioned into
_streaming multiprocessors_ (SMs). The way this manifests in the CUDA programming model
is that each _CUDA thread block_ (or CTA) is scheduled on exactly one SM, but multiple
blocks can be scheduled onto a single SM at a time.

Each SM contains a chunk of fast memory called _shared memory_ (SMEM) and 4 subdivisions,
each containing a _warp scheduler_ and compute units (ALU, TensorCore, ...).
This is also reflected in CUDA programs: each _warp_ (a group of 32 consecutive CUDA
threads in a block) is assigned to one of those subdivisions in a round-robin fashion.
Similarly to blocks, each warp is assigned to exactly one subdivision (it never migrates),
but multiple warps can be assigned to the same SM subdivision. At each clock cycle, the
warp scheduler from each subdivision tries to select one of its resident warps to execute
the next instruction.

![A diagram of one SM](../../_static/pallas/gpu/nvidia_sm.svg)

Going further, recent CUDA versions also outline the concept of a _warpgroup_, which is
a group of 4 consecutive warps. Knowing what the hardware looks like, we can see where this
is coming from: 4 consecutive warps occupy the 4 quarters of an SM and let us issue
instructions that utilize the whole SM.

> A GPU can be viewed in many different ways and here we want to focus on a slightly
  simplified model that is very TensorCore-centric. This should help you navigate the
  complexities of writing kernels involving the TensorCore, but keep in mind that the
  real picture is more complicated.

For our purposes, TensorCore operations have grown so big that it no longer makes much
sense to follow the CUDA model. As such, to us, a GPU is a collection of single-threaded cores
(SMs) with one thread of Pallas:MGPU corresponding to a CUDA warpgroup. In this model, each
operation you perform in the kernel occupies the whole CUDA warpgroup, and its constituent
warps always run in lockstep (modulo the jitter from hardware scheduling) and never take
different paths through control flow (with the small exception of `core_map` that we will
discuss later). One notable addition here is that we still allow you to co-schedule multiple
of those Pallas-level threads on the same SM so that they can cooperate and communicate
through shared memory (we realize that by putting them in the same CUDA block).

> This is very similar to a programming model popularized by [Triton](https://triton-lang.org/),
  but as you will see there are a few differences. Mosaic GPU tends to be more low-level,
  which usually means you will have to put in more work, but it also puts you more in control.
  In our view both approaches have their merits and we encourage you to pick the backend that
  suits your needs the best! Pallas supports and will continue to support Triton as an alternative
  GPU backend.

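To make this model concrete, below is a minimal sketch of a kernel whose entire body is
executed by a single Pallas thread, i.e. one CUDA warpgroup. It assumes the imports above
together with a `plgpu.kernel` entry point taking an `out_shape` argument; treat those names
and signatures as assumptions to verify against the current API rather than as a definitive
reference.

```python
import jax
import jax.numpy as jnp

def add_one(x_ref, o_ref):
  # Every operation below is issued by one Pallas thread and therefore
  # occupies an entire CUDA warpgroup (128 CUDA threads) at a time.
  o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((128, 128), jnp.float32)
# Assumed entry point: builds a JAX-callable function from the kernel body.
add_one_gpu = plgpu.kernel(add_one, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))
y = add_one_gpu(x)
```
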
### In-order execution & using multiple hardware units

Unlike more complicated CPU architectures, GPUs only support in-order execution. That, however,
does not mean that at any given time only a single instruction is running! Each SM quarter
has multiple independent functional units: TensorCore, arithmetic logic unit (ALU),
load/store unit (LSU), special function unit (SFU). If the first instruction targets one of the
units and is followed by another one (that does not use the result of the first one), then the
warp scheduler can issue the second one before the first one completes. This is often referred
to as instruction-level parallelism (ILP) and is a common theme in modern TensorCore kernels:
TensorCore operations are so big and take so many cycles to complete that it is a waste not to
try to use other units in the meantime.

To extend this even further, we can take advantage of this hardware-unit-level parallelism by
allowing multiple Pallas threads (warpgroups) to run concurrently. If one of the threads primarily
occupies the ALU, while another one primarily issues TensorCore-related instructions, we can
take advantage of the efficient context switching built into the warp schedulers to keep both
units busy. This is one of the core ideas behind algorithms such as [FlashAttention 3](https://arxiv.org/abs/2407.08608)
or the [CUTLASS ping-pong matmul kernels](https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/).

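As a sketch of what such specialization can look like in Pallas:MGPU, the kernel below
co-schedules two Pallas threads in the same CUDA block and branches on the thread index so that
each one keeps a different hardware unit busy. The `num_threads` and `thread_name` arguments,
and the use of `jax.lax.axis_index` to recover the thread index, are assumptions about the
entry point, so verify them against the current API before relying on this pattern.

```python
import jax
import jax.numpy as jnp

def specialized_kernel(x_ref, o_ref):
  # Both Pallas threads (warpgroups) execute this body concurrently on one SM.
  wg = jax.lax.axis_index("wg")

  @pl.when(wg == 0)
  def _memory_thread():
    # In a real kernel this thread would issue asynchronous copies,
    # keeping the load/store units busy.
    o_ref[...] = x_ref[...]

  @pl.when(wg == 1)
  def _compute_thread():
    # ...while this thread would issue the TensorCore (MMA) instructions.
    pass

f = plgpu.kernel(
    specialized_kernel,
    out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
    num_threads=2,     # two Pallas threads per CUDA block (assumed parameter)
    thread_name="wg",  # name used to query the thread index above (assumed parameter)
)
y = f(jnp.zeros((128, 128), jnp.float32))
```
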
For more information on how warp scheduling and instruction issue works, we recommend reading
[Analyzing Modern NVIDIA GPU cores](https://arxiv.org/abs/2503.20481).

## Array layouts and reference transforms

TODO

## MMA (TensorCore)

In this section, we focus on how Pallas:MGPU kernels can utilize the TensorCore unit.
NVIDIA continues to change the programming interface of the TensorCore significantly
between different hardware generations, which is why the lowest-level interfaces
differ in Pallas:MGPU as well.

### Hopper (`wgmma`)

TODO

### Blackwell (`tcgen05`)

TODO

## Using `core_map`

TODO

## Synchronization structures and primitives

### `commit_smem`

TODO

### `Barrier`

This is essentially a thin wrapper around an array of PTX `mbarrier` types and is
passed in as a reference. All functions involving barriers expect to get only a single
barrier argument, and so if the reference contains multiple, you have to extract one
of them explicitly using `barriers.at[index]`.

`Barrier`s are always allocated in SMEM and as such have relatively low overheads.
There are three primary use cases that require the use of `Barrier`s (the first one is
sketched right after this list):

1. Awaiting asynchronous GMEM-to-SMEM copies

   TODO

2. Cross-warpgroup synchronization

   TODO

3. Awaiting `tcgen05` TensorCore instructions

   TODO

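As a minimal illustration of use case 1 (and of picking a single barrier out of a larger
allocation with `.at[]`), consider the sketch below. The `plgpu.SMEM`, `plgpu.Barrier`,
`plgpu.copy_gmem_to_smem`, and `plgpu.barrier_wait` names, as well as the `scratch_shapes`
argument to `plgpu.kernel`, are assumptions about the API; double-check them before relying
on this.

```python
import jax
import jax.numpy as jnp

def copy_and_add(x_gmem, o_gmem, x_smem, barriers):
  # Kick off an asynchronous GMEM-to-SMEM copy that arrives on barriers.at[0]
  # once all of the data has landed in SMEM.
  plgpu.copy_gmem_to_smem(x_gmem, x_smem, barriers.at[0])
  # Block this Pallas thread until the copy signals the barrier.
  plgpu.barrier_wait(barriers.at[0])
  o_gmem[...] = x_smem[...] + 1.0

x = jnp.zeros((128, 128), jnp.float32)
f = plgpu.kernel(
    copy_and_add,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[
        plgpu.SMEM((128, 128), jnp.float32),  # SMEM staging buffer
        plgpu.Barrier(num_barriers=2),        # an array of two barriers; only .at[0] is used
    ],
)
y = f(x)
```
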
### `ClusterBarrier`

TODO

### `Semaphore`

TODO

## Asynchronous copies

TODO

## Inline Mosaic GPU

TODO

## Compiler parameters

TODO

docs/pallas/index.rst

Lines changed: 7 additions & 1 deletion
@@ -26,11 +26,17 @@ See also the :class:`jax.experimental.pallas` module API documentation.
 
 
 .. toctree::
-   :caption: Platform Features
+   :caption: TPU backend guide
    :maxdepth: 2
 
    tpu/index
 
+.. toctree::
+   :caption: Mosaic GPU backend guide
+   :maxdepth: 2
+
+   gpu/index
+
 .. toctree::
    :caption: Design Notes
    :maxdepth: 2
