
Commit e05afef

[Pallas] Pallas documentation cleanup
1 parent 1a3c9c4 commit e05afef

File tree: 9 files changed, +88 -76 lines changed
File renamed without changes.

docs/pallas/design.md renamed to docs/pallas/design/design.md

Lines changed: 4 additions & 4 deletions
@@ -94,7 +94,7 @@ Pallas kernels via JAX transformations.
 
 <center>
 
-![Pallas lowering path](../_static/pallas/pallas_flow.png)
+![Pallas lowering path](../../_static/pallas/pallas_flow.png)
 Visualization of Pallas lowering paths
 
 </center>
@@ -413,10 +413,10 @@ verify the correctness of the Triton and Mosaic compilers.
 One could also imagine perturbing the `scan` ordering to simulate the
 parallel reads and writes that happen on GPU.
 
-### Examples
+### GPU Examples
 
-Note all the following examples are for GPU only. They will require some small
-changes to work on TPUs.
+Note all the following examples are for GPU only. They will require tweaks to
+the block sizes to work on TPUs.
 
 #### `add`
 
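For context on the interpret mode these design notes lean on: passing `interpret=True` to `pallas_call` runs a kernel under the JAX interpreter, with no GPU or TPU required. Below is a minimal sketch using only APIs that appear elsewhere in this commit (`pl.pallas_call`, `interpret=True`); the kernel itself is an illustrative assumption, not a listing from the design doc.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Read the whole input block, add one, write the whole output block.
  o_ref[...] = x_ref[...] + 1

# interpret=True executes the kernel with JAX's interpreter, which is
# handy for checking kernel logic before targeting Triton or Mosaic.
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    interpret=True,
)(jnp.arange(8, dtype=jnp.int32))
```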
docs/pallas/design/index.rst

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+Pallas Design Notes
+===================
+
+.. toctree::
+   :caption: Design
+   :maxdepth: 2
+
+   design
+   async_note

docs/pallas/grid_blockspec.md

Lines changed: 3 additions & 33 deletions
@@ -44,39 +44,7 @@ For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
 You can also use {func}`jax.experimental.pallas.num_programs` to get the
 grid size for a given axis.
 
-Here's an example kernel that uses a `grid` and `program_id`.
-
-```python
->>> import jax
->>> from jax.experimental import pallas as pl
-
->>> def iota_kernel(o_ref):
-...   i = pl.program_id(0)
-...   o_ref[i] = i
-
-```
-
-We now execute it using `pallas_call` with an additional `grid` argument.
-
-```python
->>> def iota(size: int):
-...   return pl.pallas_call(iota_kernel,
-...                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
-...                         grid=(size,), interpret=True)()
->>> iota(8)
-Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
-
-```
-
-On GPUs, each program is executed in parallel on separate thread blocks.
-Thus, we need to think about race conditions on writes to HBM.
-A reasonable approach is to write our kernels in such a way that different
-programs write to disjoint places in HBM to avoid these parallel writes.
-
-On TPUs, programs are executed in a combination of parallel and sequential
-(depending on the architecture) so there are slightly different considerations.
-
-See {ref}`pallas_tpu_noteworthy_properties`.
+See {ref}`grids_by_example` for a simple kernel that uses this API.
 
 (pallas_blockspec)=
 
@@ -131,6 +99,8 @@ shape `x_shape` are computed as in the function `slice_for_invocation`
 below:
 
 ```python
+>>> import jax
+>>> from jax.experimental import pallas as pl
 >>> def slices_for_invocation(x_shape: tuple[int, ...],
 ...                           x_spec: pl.BlockSpec,
 ...                           grid: tuple[int, ...],
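To make the `BlockSpec` machinery above concrete, here is a minimal sketch of a blocked elementwise kernel. The block size of 128 and the kernel are illustrative assumptions, and the `pl.BlockSpec(block_shape, index_map)` argument order follows the current API:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_kernel(x_ref, o_ref):
  # Each grid invocation sees one (128,)-shaped block of the input.
  o_ref[...] = x_ref[...] * 2.0

def double(x: jax.Array) -> jax.Array:
  assert x.shape[0] % 128 == 0
  return pl.pallas_call(
      double_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=(x.shape[0] // 128,),
      # index_map returns *block* indices: invocation i covers
      # elements [i * 128, (i + 1) * 128).
      in_specs=[pl.BlockSpec((128,), lambda i: (i,))],
      out_specs=pl.BlockSpec((128,), lambda i: (i,)),
  )(x)
```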
docs/pallas/index.rst

Lines changed: 2 additions & 3 deletions
@@ -22,7 +22,6 @@ See also the :class:`jax.experimental.pallas` module API documentation.
    :maxdepth: 2
 
    quickstart
-   design
    grid_blockspec
 
 
@@ -34,9 +33,9 @@ See also the :class:`jax.experimental.pallas` module API documentation.
 
 .. toctree::
    :caption: Design Notes
-   :maxdepth: 1
+   :maxdepth: 2
 
-   async_note
+   design/index
 
 .. toctree::
    :caption: Other

docs/pallas/quickstart.ipynb

Lines changed: 29 additions & 18 deletions
@@ -72,8 +72,9 @@
     "\n",
     "Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n",
     "it does not take in `jax.Array`s as inputs and doesn't return any values.\n",
-    "Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n",
-    "but we are given an `o_ref`, which corresponds to the desired output.\n",
+    "Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory.\n",
+    "Note that we also don't have any outputs but we are given an `o_ref`, which corresponds\n",
+    "to the desired output.\n",
     "\n",
     "**Reading from `Ref`s**\n",
     "\n",
@@ -150,7 +151,8 @@
     "**What's actually happening here?**\n",
     "\n",
     "Thus far we've described how to think about Pallas kernels but what we've actually\n",
-    "accomplished is we're writing a function that's executed very close to the compute units.\n",
+    "accomplished is we're writing a function that's executed very close to the compute units\n",
+    "since values are loaded into the innermost (fastest) portion of the memory hierarchy.\n",
     "\n",
     "On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when\n",
     "we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)\n",
@@ -195,6 +197,8 @@
     "live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n",
     "that operate on \"blocks\" of those arrays that can fit in SRAM.\n",
     "\n",
+    "(grids_by_example)=\n",
+    "\n",
     "### Grids by example\n",
     "\n",
     "To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n",
@@ -240,7 +244,8 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "We now execute it using `pallas_call` with an additional `grid` argument."
+    "We now execute it using `pallas_call` with an additional `grid` argument.\n",
+    "On GPUs, we can call the kernel directly like so:"
    ]
   },
   {
@@ -260,6 +265,7 @@
    }
   ],
   "source": [
+    "# GPU version\n",
     "def iota(size: int):\n",
     "  return pl.pallas_call(iota_kernel,\n",
     "                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
@@ -272,16 +278,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "On GPUs, each program is executed in parallel on separate threads.\n",
-    "Thus, we need to think about race conditions on writes to HBM.\n",
-    "A reasonable approach is to write our kernels in such a way that different\n",
-    "programs write to disjoint places in HBM to avoid these parallel writes.\n",
-    "On the other hand, parallelizing the computation is how we can execute\n",
-    "operations like matrix multiplications really quickly.\n",
-    "\n",
-    "On TPUs, programs are executed in a combination of parallel and sequential\n",
-    "(depending on the architecture) so there are slightly different considerations.\n",
-    "\n",
+    "TPUs distinguish between vector and scalar memory spaces and in this case the\n",
+    "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
+    "a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.\n",
     "To call the above kernel on TPU, run:"
    ]
   },
@@ -292,6 +291,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# TPU version\n",
     "from jax.experimental.pallas import tpu as pltpu\n",
     "\n",
     "def iota(size: int):\n",
@@ -307,11 +307,22 @@
    "id": "68f97b4e",
    "metadata": {},
    "source": [
-    "TPUs distinguish between vector and scalar memory spaces and in this case the\n",
-    "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
-    "a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n",
+    "### Grid semantics\n",
+    "\n",
+    "On GPUs, each program is executed in parallel on separate threads.\n",
+    "Thus, we need to think about race conditions on writes to HBM.\n",
+    "A reasonable approach is to write our kernels in such a way that different\n",
+    "programs write to disjoint locations in HBM to avoid these parallel writes.\n",
+    "On the other hand, parallelizing the computation is how we can execute\n",
+    "operations like matrix multiplications really quickly.\n",
+    "\n",
+    "In contrast, TPUs operate like a very wide SIMD machine.\n",
+    "Some TPU models contain multiple cores, but in many cases a TPU can be\n",
+    "treated as a single-threaded processor. The grid on a TPU can be\n",
+    "specified in a combination of parallel and sequential dimensions, where sequential\n",
+    "dimensions are guaranteed to run serially.\n",
     "\n",
-    "You can read more details at {ref}`pallas_grid`."
+    "You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`."
    ]
   },
   {

docs/pallas/quickstart.md

Lines changed: 28 additions & 17 deletions
@@ -53,8 +53,9 @@ def add_vectors_kernel(x_ref, y_ref, o_ref):
 
 Let's dissect this function a bit. Unlike most JAX functions you've probably written,
 it does not take in `jax.Array`s as inputs and doesn't return any values.
-Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs
-but we are given an `o_ref`, which corresponds to the desired output.
+Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory.
+Note that we also don't have any outputs but we are given an `o_ref`, which corresponds
+to the desired output.
 
 **Reading from `Ref`s**
 
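For context, the kernel named in the hunk header above is, in essence, the following; the comments restate the `Ref` semantics the new text spells out:

```python
def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Refs are mutable buffers, not values: indexing with [...] reads the
  # whole buffer out as a jax.Array...
  x, y = x_ref[...], y_ref[...]
  # ...and assigning into o_ref[...] writes the result to the output buffer.
  o_ref[...] = x + y
```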
@@ -101,7 +102,8 @@ thereof).
 **What's actually happening here?**
 
 Thus far we've described how to think about Pallas kernels but what we've actually
-accomplished is we're writing a function that's executed very close to the compute units.
+accomplished is we're writing a function that's executed very close to the compute units
+since values are loaded into the innermost (fastest) portion of the memory hierarchy.
 
 On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when
 we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)
@@ -134,6 +136,8 @@ Part of writing Pallas kernels is thinking about how to take big arrays that
 live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations
 that operate on "blocks" of those arrays that can fit in SRAM.
 
+(grids_by_example)=
+
 ### Grids by example
 
 To automatically "carve" up the inputs and outputs, you provide a `grid` and
@@ -169,28 +173,24 @@ def iota_kernel(o_ref):
 ```
 
 We now execute it using `pallas_call` with an additional `grid` argument.
+On GPUs, we can call the kernel directly like so:
 
 ```{code-cell} ipython3
+# GPU version
 def iota(size: int):
   return pl.pallas_call(iota_kernel,
                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                         grid=(size,))()
 iota(8)
 ```
 
-On GPUs, each program is executed in parallel on separate threads.
-Thus, we need to think about race conditions on writes to HBM.
-A reasonable approach is to write our kernels in such a way that different
-programs write to disjoint places in HBM to avoid these parallel writes.
-On the other hand, parallelizing the computation is how we can execute
-operations like matrix multiplications really quickly.
-
-On TPUs, programs are executed in a combination of parallel and sequential
-(depending on the architecture) so there are slightly different considerations.
-
+TPUs distinguish between vector and scalar memory spaces and in this case the
+output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
+a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.
 To call the above kernel on TPU, run:
 
 ```{code-cell} ipython3
+# TPU version
 from jax.experimental.pallas import tpu as pltpu
 
 def iota(size: int):
@@ -201,11 +201,22 @@ def iota(size: int):
 iota(8)
 ```
 
-TPUs distinguish between vector and scalar memory spaces and in this case the
-output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
-a scalar. For more details read {ref}`pallas_tpu_pipelining`.
+### Grid semantics
+
+On GPUs, each program is executed in parallel on separate threads.
+Thus, we need to think about race conditions on writes to HBM.
+A reasonable approach is to write our kernels in such a way that different
+programs write to disjoint locations in HBM to avoid these parallel writes.
+On the other hand, parallelizing the computation is how we can execute
+operations like matrix multiplications really quickly.
+
+In contrast, TPUs operate like a very wide SIMD machine.
+Some TPU models contain multiple cores, but in many cases a TPU can be
+treated as a single-threaded processor. The grid on a TPU can be
+specified in a combination of parallel and sequential dimensions, where sequential
+dimensions are guaranteed to run serially.
 
-You can read more details at {ref}`pallas_grid`.
+You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`.
 
 +++
 
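To ground the disjoint-writes guidance in the new "Grid semantics" text, here is a minimal sketch in the style of the quickstart's `iota_kernel`; the doubling is illustrative, and `interpret=True` lets it run without an accelerator:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_iota_kernel(o_ref):
  i = pl.program_id(0)
  # Program i writes only to index i, so even when the grid's programs
  # run in parallel on a GPU their HBM writes never overlap.
  o_ref[i] = i * 2

out = pl.pallas_call(
    double_iota_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    grid=(8,),
    interpret=True,
)()
# out == Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)
```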
docs/pallas/tpu/pipelining.ipynb

Lines changed: 9 additions & 1 deletion
@@ -48,12 +48,20 @@
   },
   {
    "cell_type": "markdown",
+   "id": "0e212a5e",
    "metadata": {
     "id": "TWKESTKAlyjT"
    },
    "source": [
-    "## TPU and its memory spaces\n",
+    "(tpu_and_its_memory_spaces)=\n",
     "\n",
+    "## TPU and its memory spaces"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
     "A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n",
     "registers (which temporarily store scalar and array values) and compute units\n",
     "(that do computation with values in registers).\n",

docs/pallas/tpu/pipelining.md

Lines changed: 4 additions & 0 deletions
@@ -38,8 +38,12 @@ import numpy as np
 
 +++ {"id": "TWKESTKAlyjT"}
 
+(tpu_and_its_memory_spaces)=
+
 ## TPU and its memory spaces
 
++++
+
 A TPU and its TensorCore consist of memory spaces (where arrays can reside),
 registers (which temporarily store scalar and array values) and compute units
 (that do computation with values in registers).
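One concrete use of these memory spaces is the quickstart's TPU `iota`, whose listing is truncated in the hunks above: its scalar output is pinned to SMEM. A minimal sketch of that pattern, assuming the `pltpu.TPUMemorySpace.SMEM` spelling the quickstart text references:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def iota_kernel(o_ref):
  i = pl.program_id(0)
  o_ref[i] = i

def iota(size: int):
  return pl.pallas_call(
      iota_kernel,
      out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
      # `i` is a scalar, so the output lives in scalar memory (SMEM)
      # rather than the default vector memory (VMEM).
      out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
      grid=(size,),
  )()
```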
