File tree Expand file tree Collapse file tree 2 files changed +31
-12
lines changed Expand file tree Collapse file tree 2 files changed +31
-12
lines changed Original file line number Diff line number Diff line change @@ -498,20 +498,30 @@ def err_details():
498498 (bs0 == as0 or bs0 % 128 == 0 ) and
499499 (bs1 == as1 or bs1 % 8 == 0 )
500500 )
501+ if not evenly_divisible :
502+ raise ValueError (
503+ "The Pallas TPU lowering currently requires that the last two "
504+ "dimensions of your block shape are divisible by 8 and 128 "
505+ "respectively, or be equal to the respective dimensions of the "
506+ "overall array. "
507+ + err_details ()
508+ )
501509 else :
502510 assert rank == 1
503511 # TODO(necula): test this for bool. What should it do?
504512 tiling_size = 128 * (32 // lax_internal ._bit_width (bm .array_shape_dtype .dtype ))
505513 evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0 )
506-
507- if not evenly_divisible :
508- raise ValueError (
509- "The Pallas TPU lowering currently requires that the last two "
510- "dimensions of your block shape are divisible by 8 and 128 "
511- "respectively, or be equal to the respective dimensions of the "
512- "overall array. "
513- + err_details ()
514- )
514+ if not evenly_divisible :
515+ raise ValueError (
516+ "The Pallas TPU lowering currently requires that rank 1 block"
517+ " shapes, either 1) the first (and only) dimension of the block"
518+ " shape is equal to the first (and only) dimension of the array"
519+ " shape, or 2) the first (and only) dimension of the block shape"
520+ f" is a multiple of the tiling size ({ tiling_size } = 128 * (32 //"
521+ f" { lax_internal ._bit_width (bm .array_shape_dtype .dtype )} )) of the"
522+ " array shape. "
523+ + err_details ()
524+ )
515525
516526
517527def lower_jaxpr_to_module (
Original file line number Diff line number Diff line change @@ -386,9 +386,18 @@ def copy_kernel(x_ref, o_ref):
386386 (bs0 == as0 or bs0 % 128 == 0 ) and
387387 (bs1 == as1 or bs1 % 8 == 0 ))
388388 if not evenly_divisible :
389- test_context = self .assertRaisesRegex (
390- ValueError ,
391- "last two dimensions of your block shape are divisible by 8 and 128" )
389+ if rank == 1 :
390+ test_context = self .assertRaisesRegex (
391+ ValueError ,
392+ r"the first \(and only\) dimension of the block shape is a"
393+ " multiple of the tiling size" ,
394+ )
395+ else :
396+ test_context = self .assertRaisesRegex (
397+ ValueError ,
398+ "last two dimensions of your block shape are divisible by 8"
399+ " and 128" ,
400+ )
392401
393402 elif jtu .test_device_matches (["gpu" ]) and not self .INTERPRET :
394403 block_size = math .prod (block_shape )
You can’t perform that action at this time.
0 commit comments