Skip to content

Commit a83c167

Browse files
Adjust tiling in layout_test for TPU v7.
The entry tiling need to match with tpu_comp_env PiperOrigin-RevId: 834775604
1 parent 9ebf41a commit a83c167

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/layout_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,9 @@ def test_layout_donation_mismatching_in_and_out_fails(self):
579579
shape = (16*2, 32016*2)
580580
np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)
581581

582-
custom_dll1 = Layout(major_to_minor=(1, 0), tiling=((8,128), (2,1)))
582+
tiling = (((16, 128), (2, 1)) if jtu.get_tpu_version() == 7
583+
else ((8, 128), (2, 1)))
584+
custom_dll1 = Layout(major_to_minor=(1, 0), tiling=tiling)
583585
l1 = Format(custom_dll1, s)
584586
arr = jax.device_put(np_inp, s)
585587

0 commit comments

Comments
 (0)