|
72 | 72 | "\n", |
73 | 73 | "Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n", |
74 | 74 | "it does not take in `jax.Array`s as inputs and doesn't return any values.\n", |
75 | | - "Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n", |
76 | | - "but we are given an `o_ref`, which corresponds to the desired output.\n", |
| 75 | + "Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory.\n", |
| 76 | + "Note that we also don't have any outputs but we are given an `o_ref`, which corresponds\n", |
| 77 | + "to the desired output.\n", |
77 | 78 | "\n", |
78 | 79 | "**Reading from `Ref`s**\n", |
79 | 80 | "\n", |
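To make the `Ref` discussion in this cell concrete for readers of the diff, here is a minimal sketch in the style the notebook introduces: the kernel reads full values out of its input `Ref`s and assigns the result into the output `Ref`, and `pallas_call` turns it back into a function on `jax.Array`s. The wrapper names below are illustrative, not part of this change.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Reading a Ref with [...] materializes its contents as a jax.Array;
    # assigning to o_ref[...] writes the result into the output buffer.
    o_ref[...] = x_ref[...] + y_ref[...]

def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    # pallas_call wraps the Ref-based kernel into a jax.Array -> jax.Array function.
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

add_vectors(jnp.arange(8), jnp.arange(8))
```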
|
150 | 151 | "**What's actually happening here?**\n", |
151 | 152 | "\n", |
152 | 153 | "Thus far we've described how to think about Pallas kernels but what we've actually\n", |
153 | | - "accomplished is we're writing a function that's executed very close to the compute units.\n", |
| 154 | + "accomplished is we're writing a function that's executed very close to the compute units\n", |
| 155 | + "since values are loaded into the innermost (fastest) portion of the memory hierarchy.\n", |
154 | 156 | "\n", |
155 | 157 | "On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when\n", |
156 | 158 | "we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)\n", |
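As a rough sketch of the point this cell makes (assuming the same `Ref`-indexing API as above), the memory movement happens exactly at the moments the kernel reads and writes:

```python
def copy_and_double_kernel(x_ref, o_ref):
    # On GPU, x_ref refers to a value in HBM; this read copies it into
    # fast on-chip memory (SRAM) before it reaches the compute units.
    x = x_ref[...]
    # Compute operates on the fast copy ...
    y = x * 2
    # ... and this write copies the result back out toward HBM.
    o_ref[...] = y
```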
|
195 | 197 | "live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n", |
196 | 198 | "that operate on \"blocks\" of those arrays that can fit in SRAM.\n", |
197 | 199 | "\n", |
| 200 | + "(grids_by_example)=\n", |
| 201 | + "\n", |
198 | 202 | "### Grids by example\n", |
199 | 203 | "\n", |
200 | 204 | "To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n", |
|
240 | 244 | "cell_type": "markdown", |
241 | 245 | "metadata": {}, |
242 | 246 | "source": [ |
243 | | - "We now execute it using `pallas_call` with an additional `grid` argument." |
| 247 | + "We now execute it using `pallas_call` with an additional `grid` argument.\n", |
| 248 | + "On GPUs, we can call the kernel directly like so:" |
244 | 249 | ] |
245 | 250 | }, |
246 | 251 | { |
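For readers of this excerpt: the `iota_kernel` being executed here is defined earlier in the notebook, outside the lines shown in this diff. It looks roughly like the following sketch, using `pl.program_id` (with `pl` being `jax.experimental.pallas` as imported earlier) to pick which output element each program instance writes:

```python
# Sketch of the kernel defined earlier in the notebook (not part of this diff).
def iota_kernel(o_ref):
    i = pl.program_id(0)   # index of this program instance along grid axis 0
    o_ref[i] = i           # each instance writes exactly one output element
```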
|
260 | 265 | } |
261 | 266 | ], |
262 | 267 | "source": [ |
| 268 | + "# GPU version\n", |
263 | 269 | "def iota(size: int):\n", |
264 | 270 | " return pl.pallas_call(iota_kernel,\n", |
265 | 271 | " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", |
|
272 | 278 | "cell_type": "markdown", |
273 | 279 | "metadata": {}, |
274 | 280 | "source": [ |
275 | | - "On GPUs, each program is executed in parallel on separate threads.\n", |
276 | | - "Thus, we need to think about race conditions on writes to HBM.\n", |
277 | | - "A reasonable approach is to write our kernels in such a way that different\n", |
278 | | - "programs write to disjoint places in HBM to avoid these parallel writes.\n", |
279 | | - "On the other hand, parallelizing the computation is how we can execute\n", |
280 | | - "operations like matrix multiplications really quickly.\n", |
281 | | - "\n", |
282 | | - "On TPUs, programs are executed in a combination of parallel and sequential\n", |
283 | | - "(depending on the architecture) so there are slightly different considerations.\n", |
284 | | - "\n", |
| 281 | + "TPUs distinguish between vector and scalar memory spaces and in this case the\n", |
| 282 | + "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", |
| 283 | + "a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.\n", |
285 | 284 | "To call the above kernel on TPU, run:" |
286 | 285 | ] |
287 | 286 | }, |
|
292 | 291 | "metadata": {}, |
293 | 292 | "outputs": [], |
294 | 293 | "source": [ |
| 294 | + "# TPU version\n", |
295 | 295 | "from jax.experimental.pallas import tpu as pltpu\n", |
296 | 296 | "\n", |
297 | 297 | "def iota(size: int):\n", |
|
307 | 307 | "id": "68f97b4e", |
308 | 308 | "metadata": {}, |
309 | 309 | "source": [ |
310 | | - "TPUs distinguish between vector and scalar memory spaces and in this case the\n", |
311 | | - "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", |
312 | | - "a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n", |
| 310 | + "### Grid semantics\n", |
| 311 | + "\n", |
| 312 | + "On GPUs, each program is executed in parallel on separate threads.\n", |
| 313 | + "Thus, we need to think about race conditions on writes to HBM.\n", |
| 314 | + "A reasonable approach is to write our kernels in such a way that different\n", |
| 315 | + "programs write to disjoint locations in HBM to avoid these parallel writes.\n", |
| 316 | + "On the other hand, parallelizing the computation is how we can execute\n", |
| 317 | + "operations like matrix multiplications really quickly.\n", |
| 318 | + "\n", |
| 319 | + "In contrast, TPUs operate like a very wide SIMD machine.\n", |
| 320 | + "Some TPU models contain multiple cores, but in many cases a TPU can be\n", |
| 321 | + "treated as a single-threaded processor. The grid on a TPU can be\n", |
| 322 | + "specified in a combination of parallel and sequential dimensions, where sequential\n", |
| 323 | + "dimensions are guaranteed to run serially.\n", |
313 | 324 | "\n", |
314 | | - "You can read more details at {ref}`pallas_grid`." |
| 325 | + "You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`." |
315 | 326 | ] |
316 | 327 | }, |
317 | 328 | { |
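As a hedged illustration of the disjoint-writes guideline above (the names, shapes, and block size here are assumptions, not part of this change, and the sketch assumes the `pl.BlockSpec(block_shape, index_map)` argument order), each program instance can be given its own block of the input and output so that parallel instances never write to the same locations in HBM:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_block_kernel(x_ref, o_ref):
    # With matching BlockSpecs, each program instance sees only its own block
    # of x and o, so concurrent instances write to disjoint regions of HBM.
    o_ref[...] = x_ref[...] * 2

def double(x: jax.Array, block_rows: int = 128) -> jax.Array:
    n_rows, n_cols = x.shape  # assumes n_rows is divisible by block_rows
    return pl.pallas_call(
        double_block_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(n_rows // block_rows,),
        in_specs=[pl.BlockSpec((block_rows, n_cols), lambda i: (i, 0))],
        out_specs=pl.BlockSpec((block_rows, n_cols), lambda i: (i, 0)),
    )(x)

double(jnp.ones((256, 512), jnp.float32))
```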
|