Skip to content

Commit 1222b4a

Browse files
pajarskasGoogle-ML-Automation
authored andcommitted
[Pallas TPU] Add a better error message for rank 1 block mappings check.
Currently, the error message refers to "last two dimensions" which is confusing for a rank-1 case; furthermore, the error does not match the check in the code. PiperOrigin-RevId: 686520781
1 parent 4c0d828 commit 1222b4a

File tree

2 files changed

+31
-12
lines changed

2 files changed

+31
-12
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -498,20 +498,30 @@ def err_details():
498498
(bs0 == as0 or bs0 % 128 == 0) and
499499
(bs1 == as1 or bs1 % 8 == 0)
500500
)
501+
if not evenly_divisible:
502+
raise ValueError(
503+
"The Pallas TPU lowering currently requires that the last two "
504+
"dimensions of your block shape are divisible by 8 and 128 "
505+
"respectively, or be equal to the respective dimensions of the "
506+
"overall array. "
507+
+ err_details()
508+
)
501509
else:
502510
assert rank == 1
503511
# TODO(necula): test this for bool. What should it do?
504512
tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype))
505513
evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0)
506-
507-
if not evenly_divisible:
508-
raise ValueError(
509-
"The Pallas TPU lowering currently requires that the last two "
510-
"dimensions of your block shape are divisible by 8 and 128 "
511-
"respectively, or be equal to the respective dimensions of the "
512-
"overall array. "
513-
+ err_details()
514-
)
514+
if not evenly_divisible:
515+
raise ValueError(
516+
"The Pallas TPU lowering currently requires that rank 1 block"
517+
" shapes, either 1) the first (and only) dimension of the block"
518+
" shape is equal to the first (and only) dimension of the array"
519+
" shape, or 2) the first (and only) dimension of the block shape"
520+
f" is a multiple of the tiling size ({tiling_size} = 128 * (32 //"
521+
f" {lax_internal._bit_width(bm.array_shape_dtype.dtype)})) of the"
522+
" array shape. "
523+
+ err_details()
524+
)
515525

516526

517527
def lower_jaxpr_to_module(

tests/pallas/pallas_test.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,18 @@ def copy_kernel(x_ref, o_ref):
386386
(bs0 == as0 or bs0 % 128 == 0) and
387387
(bs1 == as1 or bs1 % 8 == 0))
388388
if not evenly_divisible:
389-
test_context = self.assertRaisesRegex(
390-
ValueError,
391-
"last two dimensions of your block shape are divisible by 8 and 128")
389+
if rank == 1:
390+
test_context = self.assertRaisesRegex(
391+
ValueError,
392+
r"the first \(and only\) dimension of the block shape is a"
393+
" multiple of the tiling size",
394+
)
395+
else:
396+
test_context = self.assertRaisesRegex(
397+
ValueError,
398+
"last two dimensions of your block shape are divisible by 8"
399+
" and 128",
400+
)
392401

393402
elif jtu.test_device_matches(["gpu"]) and not self.INTERPRET:
394403
block_size = math.prod(block_shape)

0 commit comments

Comments
 (0)