@@ -1340,11 +1340,7 @@ def if_true(z):
13401340 np .testing .assert_allclose (f (jnp .bool_ (False ), arg ),
13411341 - arg )
13421342
1343- # We actually expect the assertion failure in linearize, but this also
1344- # covers another case where an effect was causing an earlier assertion
1345- # failure.
1346- with self .assertRaises (AssertionError ):
1347- # Notably, we should not have a ValueError for mismatched Read<N> effect.
1343+ with self .assertRaisesRegex (ValueError , "Linearization failed" ):
13481344 _ = jax .grad (lambda x : jnp .sum (f (jnp .bool_ (True ), x )** 2 ))(arg )
13491345 # np.testing.assert_allclose(
13501346 # dx, jnp.float32([0., 2, 4, 6, 0, 10, 12 + 12, 14]))
@@ -1397,7 +1393,7 @@ def body_fn(i, args):
13971393 16 * x * params [4 , 2 ])
13981394 np .testing .assert_allclose (f (program , params , x ), expected )
13991395
1400- with self .assertRaises ( AssertionError ):
1396+ with self .assertRaisesRegex ( ValueError , "Linearization failed" ):
14011397 jax .value_and_grad (lambda params , x : f (program , params , x ).sum ())(
14021398 params , x )
14031399
@@ -1451,7 +1447,7 @@ def body_fn(i, args):
14511447 16 * x * params [4 , 2 ])
14521448 np .testing .assert_allclose (f (program , params , x ), expected )
14531449
1454- with self .assertRaises ( AssertionError ):
1450+ with self .assertRaisesRegex ( ValueError , "Linearization failed" ):
14551451 jax .value_and_grad (lambda params , x : f (program , params , x ).sum ())(
14561452 params , x )
14571453
0 commit comments