Skip to content

Commit 66b9005

Browse files
kanglantGoogle-ML-Automation
authored andcommitted
Disable pjit ArrayPjitTest.test_device_put_grad test on TPU v5e
PiperOrigin-RevId: 704378732
1 parent 1c07ec6 commit 66b9005

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tests/pjit_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3411,6 +3411,9 @@ def test_device_assignment_mismatch_apply_primitive(self):
34113411
def test_device_put_grad(self):
34123412
if jax.device_count() < 8:
34133413
self.skipTest("Requires >=8 devices.")
3414+
if jtu.is_device_tpu(5, 'e'):
3415+
self.skipTest('TPU v5e does not support computations that run on a '
3416+
'non-singleton subset of cores.')
34143417

34153418
def _test(fun, inp, np_inp, in_s):
34163419
out = fun(inp)

0 commit comments

Comments
 (0)