Skip to content

Commit 250e9fe

Browse files
Fix test flake
1 parent f86db57 commit 250e9fe

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/test_nn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,7 +1096,7 @@ def test_weight_norm(getkey):
10961096
out_weight_norm = weight_norm_linear(x)
10971097
out_linear = linear(x)
10981098

1099-
assert jnp.allclose(out_weight_norm, out_linear)
1099+
assert jnp.allclose(out_weight_norm, out_linear, atol=1e-4, rtol=1e-4)
11001100

11011101
# Axis == None
11021102
linear = eqx.nn.Linear(4, 4, key=getkey())
@@ -1108,7 +1108,7 @@ def test_weight_norm(getkey):
11081108
out_weight_norm = weight_norm_linear(x)
11091109
out_linear = linear(x)
11101110

1111-
assert jnp.allclose(out_weight_norm, out_linear)
1111+
assert jnp.allclose(out_weight_norm, out_linear, atol=1e-4, rtol=1e-4)
11121112

11131113
# Conv3d (ndim weight matrices > 2)
11141114
conv = eqx.nn.Conv3d(2, 3, 3, key=getkey())
@@ -1117,7 +1117,7 @@ def test_weight_norm(getkey):
11171117
out_weight_norm = weight_norm_conv(x)
11181118
out_conv = conv(x)
11191119

1120-
assert jnp.allclose(out_weight_norm, out_conv)
1120+
assert jnp.allclose(out_weight_norm, out_conv, atol=1e-4, rtol=1e-4)
11211121

11221122
# Grads get generated for reparametrized weights, not original
11231123
grads = eqx.filter_grad(lambda model, x: jnp.mean(model(x)))(

0 commit comments

Comments
 (0)