diff --git a/tests/nnx/nn/attention_test.py b/tests/nnx/nn/attention_test.py
index bbc48847a..9ce35f058 100644
--- a/tests/nnx/nn/attention_test.py
+++ b/tests/nnx/nn/attention_test.py
@@ -206,7 +206,7 @@ class DummyModule(nnx.Module):
     np.testing.assert_allclose(attn_jax, attn_manual, atol=1e-6)
 
-# TODO: add all possible constructor argument values to parameterized.product
+
 class TestLinenConsistency(parameterized.TestCase):
 
   @parameterized.product(
     use_bias=[True, False],
@@ -215,6 +215,8 @@ class TestLinenConsistency(parameterized.TestCase):
     precision=[Precision.DEFAULT, Precision.HIGH, Precision.HIGHEST],
     decode=[True, False],
     normalize_qk=[True, False],
+    qkv_features=[None, 8],
+    out_features=[None, 6],
   )
   def test_nnx_attention_equivalence(
     self,
@@ -224,16 +226,16 @@ def test_nnx_attention_equivalence(
     precision: PrecisionLike,
     decode: bool,
     normalize_qk: bool,
+    qkv_features: tp.Optional[int],
+    out_features: tp.Optional[int],
   ):
     key = jax.random.key(42)
     rngs = nnx.Rngs(42)
 
     num_heads = 2
-    in_features = 3
-    qkv_features = 6
-    out_features = 6
+    in_features = 4
 
-    x = jax.numpy.ones((1, in_features))
+    x = jnp.ones((1, in_features))
     model_nnx = nnx.MultiHeadAttention(
       num_heads=num_heads,
       in_features=in_features,
@@ -264,12 +266,37 @@ def test_nnx_attention_equivalence(
       getattr(model_nnx, qkvo).kernel[...] = variables['params'][qkvo]['kernel']
       if use_bias:
         getattr(model_nnx, qkvo).bias[...] = variables['params'][qkvo]['bias']
+    if normalize_qk:
+      model_nnx.query_ln.scale[...] = variables['params']['query_ln']['scale']
+      model_nnx.key_ln.scale[...] = variables['params']['key_ln']['scale']
+
+    # Guard: verify params were copied correctly
+    for name in ('query', 'key', 'value', 'out'):
+      np.testing.assert_array_equal(
+        variables['params'][name]['kernel'],
+        getattr(model_nnx, name).kernel[...],
+      )
+      if use_bias:
+        np.testing.assert_array_equal(
+          variables['params'][name]['bias'],
+          getattr(model_nnx, name).bias[...],
+        )
+    if normalize_qk:
+      np.testing.assert_array_equal(
+        variables['params']['query_ln']['scale'],
+        model_nnx.query_ln.scale[...],
+      )
+      np.testing.assert_array_equal(
+        variables['params']['key_ln']['scale'],
+        model_nnx.key_ln.scale[...],
+      )
 
     if decode:
       model_nnx.init_cache(x.shape, dtype=dtype)
     out_nnx = model_nnx(x)
-    out, cache = model.apply(variables, x, mutable=['cache'])
-    np.testing.assert_array_equal(out, out_nnx)
+    out, _ = model.apply(variables, x, mutable=['cache'])
+    rtol = 1e-3 if dtype == jnp.float16 or param_dtype == jnp.float16 else 1e-6
+    np.testing.assert_allclose(out, out_nnx, rtol=rtol)
 
 
 class TestKVFeatures(parameterized.TestCase):
@@ -284,7 +311,7 @@ def test_varying_num_features(self):
     qkv_features = 6
     out_features = 6
 
-    x = jax.numpy.ones((1, in_features))
+    x = jnp.ones((1, in_features))
     y = jax.random.normal(key, (1, in_kv_features))
     layer = nnx.MultiHeadAttention(
       num_heads=num_heads,
@@ -354,5 +381,3 @@ class DummyModule(nnx.Module):
 
 if __name__ == '__main__':
   absltest.main()
-
-
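
For context, a minimal usage sketch (not part of the diff) of the configuration the new `qkv_features=[None, 8]` / `out_features=[None, 6]` cases exercise. It assumes `nnx.MultiHeadAttention` resolves a `None` feature count to `in_features`, which is also the likely reason `in_features` moves from 3 to 4: the fallback `qkv_features` must be divisible by `num_heads`. Shapes and seeds are illustrative only.

```python
# Minimal sketch of the new parameterization (not part of the diff).
import jax.numpy as jnp
from flax import nnx

layer = nnx.MultiHeadAttention(
  num_heads=2,
  in_features=4,       # divisible by num_heads, needed once qkv_features may be None
  qkv_features=None,   # assumed to fall back to in_features
  out_features=None,   # assumed to fall back to in_features
  decode=False,
  rngs=nnx.Rngs(0),
)

x = jnp.ones((1, 4))
y = layer(x)           # self-attention; output keeps in_features channels when out_features is None
```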