@@ -1096,7 +1096,7 @@ def test_weight_norm(getkey):
10961096 out_weight_norm = weight_norm_linear (x )
10971097 out_linear = linear (x )
10981098
1099- assert jnp .allclose (out_weight_norm , out_linear )
1099+ assert jnp .allclose (out_weight_norm , out_linear , atol = 1e-4 , rtol = 1e-4 )
11001100
11011101 # Axis == None
11021102 linear = eqx .nn .Linear (4 , 4 , key = getkey ())
@@ -1108,7 +1108,7 @@ def test_weight_norm(getkey):
11081108 out_weight_norm = weight_norm_linear (x )
11091109 out_linear = linear (x )
11101110
1111- assert jnp .allclose (out_weight_norm , out_linear )
1111+ assert jnp .allclose (out_weight_norm , out_linear , atol = 1e-4 , rtol = 1e-4 )
11121112
11131113 # Conv3d (ndim weight matrices > 2)
11141114 conv = eqx .nn .Conv3d (2 , 3 , 3 , key = getkey ())
@@ -1117,7 +1117,7 @@ def test_weight_norm(getkey):
11171117 out_weight_norm = weight_norm_conv (x )
11181118 out_conv = conv (x )
11191119
1120- assert jnp .allclose (out_weight_norm , out_conv )
1120+ assert jnp .allclose (out_weight_norm , out_conv , atol = 1e-4 , rtol = 1e-4 )
11211121
11221122 # Grads get generated for reparametrized weights, not original
11231123 grads = eqx .filter_grad (lambda model , x : jnp .mean (model (x )))(
0 commit comments