Skip to content

Commit 4f70471

Browse files
Fix error in pallas tutorial
PiperOrigin-RevId: 737727935
1 parent 20658fa commit 4f70471

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

docs/pallas/tpu/sparse.ipynb

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -299,7 +299,7 @@
299299
" ):\n",
300300
" \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n",
301301
" del idxs_k_ref\n",
302-
" blk_idx = pl.program_id(0)\n",
302+
" blk_idx = pl.program_id(1)\n",
303303
" is_start = blk_idx == 0\n",
304304
" changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n",
305305
" @pl.when(is_start | changed_blocks)\n",
@@ -314,13 +314,13 @@
314314
" o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n",
315315
"\n",
316316
"\n",
317-
"def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
317+
"def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
318318
" del j, blk_idxs_i, blk_idxs_k\n",
319319
" return (blk_idx, 0, 0)\n",
320-
"def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
320+
"def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
321321
" del blk_idxs_i\n",
322322
" return (blk_idxs_k[blk_idx], j)\n",
323-
"def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
323+
"def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n",
324324
" del blk_idxs_k\n",
325325
" return (blk_idxs_i[blk_idx], j)\n",
326326
"\n",
@@ -335,7 +335,7 @@
335335
" num_scalar_prefetch=2,\n",
336336
" # Note that while num_blocks is static here, Pallas does support\n",
337337
" # dynamic grid sizes.\n",
338-
" grid=(num_blocks, N // blk_N),\n",
338+
" grid=(N // blk_N, num_blocks),\n",
339339
" in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n",
340340
" pl.BlockSpec((blk_K, blk_N), y_map),\n",
341341
" # Placeholder for a zeros-array used by input_output_aliases.\n",

docs/pallas/tpu/sparse.md

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -239,7 +239,7 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
239239
):
240240
"""A DSD (Dense = Sparse @ Dense) matmul kernel."""
241241
del idxs_k_ref
242-
blk_idx = pl.program_id(0)
242+
blk_idx = pl.program_id(1)
243243
is_start = blk_idx == 0
244244
changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
245245
@pl.when(is_start | changed_blocks)
@@ -254,13 +254,13 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
254254
o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
255255
256256
257-
def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
257+
def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
258258
del j, blk_idxs_i, blk_idxs_k
259259
return (blk_idx, 0, 0)
260-
def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
260+
def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
261261
del blk_idxs_i
262262
return (blk_idxs_k[blk_idx], j)
263-
def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
263+
def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
264264
del blk_idxs_k
265265
return (blk_idxs_i[blk_idx], j)
266266
@@ -275,7 +275,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec(
275275
num_scalar_prefetch=2,
276276
# Note that while num_blocks is static here, Pallas does support
277277
# dynamic grid sizes.
278-
grid=(num_blocks, N // blk_N),
278+
grid=(N // blk_N, num_blocks),
279279
in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
280280
pl.BlockSpec((blk_K, blk_N), y_map),
281281
# Placeholder for a zeros-array used by input_output_aliases.

0 commit comments

Comments
 (0)