Skip to content

Commit 978d35f

Browse files
dfmGoogle-ML-Automation
authored andcommitted
Fix expected exception type in pallas grad tests.
PiperOrigin-RevId: 704408603
1 parent 71c48cb commit 978d35f

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

tests/pallas/pallas_test.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,11 +1340,7 @@ def if_true(z):
13401340
np.testing.assert_allclose(f(jnp.bool_(False), arg),
13411341
-arg)
13421342

1343-
# We actually expect the assertion failure in linearize, but this also
1344-
# covers another case where an effect was causing an earlier assertion
1345-
# failure.
1346-
with self.assertRaises(AssertionError):
1347-
# Notably, we should not have a ValueError for mismatched Read<N> effect.
1343+
with self.assertRaisesRegex(ValueError, "Linearization failed"):
13481344
_ = jax.grad(lambda x: jnp.sum(f(jnp.bool_(True), x)**2))(arg)
13491345
# np.testing.assert_allclose(
13501346
# dx, jnp.float32([0., 2, 4, 6, 0, 10, 12 + 12, 14]))
@@ -1397,7 +1393,7 @@ def body_fn(i, args):
13971393
16 * x * params[4, 2])
13981394
np.testing.assert_allclose(f(program, params, x), expected)
13991395

1400-
with self.assertRaises(AssertionError):
1396+
with self.assertRaisesRegex(ValueError, "Linearization failed"):
14011397
jax.value_and_grad(lambda params, x: f(program, params, x).sum())(
14021398
params, x)
14031399

@@ -1451,7 +1447,7 @@ def body_fn(i, args):
14511447
16 * x * params[4, 2])
14521448
np.testing.assert_allclose(f(program, params, x), expected)
14531449

1454-
with self.assertRaises(AssertionError):
1450+
with self.assertRaisesRegex(ValueError, "Linearization failed"):
14551451
jax.value_and_grad(lambda params, x: f(program, params, x).sum())(
14561452
params, x)
14571453

0 commit comments

Comments
 (0)