diff --git a/tests/nnx/nn/linear_test.py b/tests/nnx/nn/linear_test.py
index 4558b1516..1db04e433 100644
--- a/tests/nnx/nn/linear_test.py
+++ b/tests/nnx/nn/linear_test.py
@@ -240,6 +240,281 @@ def linen_loss_function(variables):
     )
 
 
+class TestEinsum(parameterized.TestCase):
+
+  @parameterized.product(
+    einsum_str_input_kernel_bias_output=[
+      # matrix multiply: (batch, in) x (in, out) -> (batch, out)
+      ('bi,io->bo', (4, 3), (3, 5), (5,), (4, 5)),
+      # batched matmul with extra dims
+      ('nta,hab->nthb', (16, 11, 2), (8, 2, 4), (8, 4), (16, 11, 8, 4)),
+      # multi-head attention-like: (batch, seq, heads, dim)
+      ('bshd,hdo->bsho', (2, 8, 4, 16), (4, 16, 32), (4, 32),
+       (2, 8, 4, 32)),
+      # contraction over multiple dims
+      ('defab,bcef->adefc', (8, 6, 7, 3, 4), (4, 5, 6, 7), (6, 7, 5),
+       (3, 8, 6, 7, 5)),
+      # simple 2D without bias
+      ('ij,jk->ik', (2, 3), (3, 4), None, (2, 4)),
+    ],
+  )
+  def test_output_shape(self, einsum_str_input_kernel_bias_output):
+    einsum_str, input_shape, kernel_shape, bias_shape, expected_shape = (
+      einsum_str_input_kernel_bias_output
+    )
+    model = nnx.Einsum(
+      einsum_str, kernel_shape, bias_shape, rngs=nnx.Rngs(0)
+    )
+    x = jnp.ones(input_shape)
+    y = model(x)
+    self.assertEqual(y.shape, expected_shape)
+    self.assertEqual(model.kernel.shape, kernel_shape)
+    if bias_shape is not None:
+      self.assertIsNotNone(model.bias)
+      self.assertEqual(model.bias.shape, bias_shape)
+    else:
+      self.assertIsNone(model.bias)
+
+  @parameterized.product(
+    dtype=[None, jnp.float32, jnp.float16],
+    param_dtype=[jnp.float32, jnp.float16],
+    preferred_element_type=[None, jnp.float32],
+  )
+  def test_dtypes(self, dtype, param_dtype, preferred_element_type):
+    model = nnx.Einsum(
+      'bi,io->bo',
+      (3, 5),
+      (5,),
+      dtype=dtype,
+      param_dtype=param_dtype,
+      preferred_element_type=preferred_element_type,
+      rngs=nnx.Rngs(0),
+    )
+    self.assertEqual(model.kernel.dtype, param_dtype)
+    self.assertEqual(model.bias.dtype, param_dtype)
+
+    x = jnp.ones((2, 3))
+    y = model(x)
+    if preferred_element_type is not None:
+      self.assertEqual(y.dtype, preferred_element_type)
+    elif dtype is not None:
+      self.assertEqual(y.dtype, dtype)
+    else:
+      # dtype=None: output dtype inferred from input (float32) and params
+      expected_dtype = jnp.result_type(jnp.float32, param_dtype)
+      self.assertEqual(y.dtype, expected_dtype)
+
+  @parameterized.product(
+    dtype=[jnp.float32, jnp.float16],
+    param_dtype=[jnp.float32, jnp.float16],
+  )
+  def test_no_bias(self, dtype, param_dtype):
+    model = nnx.Einsum(
+      'bi,io->bo', (3, 5), bias_shape=None,
+      dtype=dtype, param_dtype=param_dtype,
+      kernel_init=nnx.initializers.ones, rngs=nnx.Rngs(0),
+    )
+    x = jnp.ones((2, 3), dtype=dtype)
+    y = model(x)
+
+    # kernel=ones: y[b,o] = sum_i(x[b,i] * 1) = 3.0 for all b,o
+    rtol = (
+      1e-3
+      if dtype == jnp.float16 or param_dtype == jnp.float16
+      else 1e-6
+    )
+    np.testing.assert_allclose(
+      y, jnp.full((2, 5), 3.0, dtype=y.dtype), rtol=rtol,
+    )
+
+  def test_bias_addition_multidim(self):
+    model = nnx.Einsum(
+      'defab,bcef->adefc',
+      (4, 5, 6, 7),
+      (6, 7, 5),
+      kernel_init=nnx.initializers.zeros,
+      bias_init=nnx.initializers.ones,
+      rngs=nnx.Rngs(0),
+    )
+    x = jnp.zeros((8, 6, 7, 3, 4))
+    # Verify _infer_broadcasted_bias_shape computes correct shape
+    # output='adefc': bias dims (e,f,c) from rhs -> [1, 1, 6, 7, 5]
+    broadcasted_shape = model._infer_broadcasted_bias_shape(
+      'defab,bcef->adefc', x, model.kernel[...]
+    )
+    self.assertEqual(broadcasted_shape, [1, 1, 6, 7, 5])
+    # kernel=0 means einsum result is 0, output should be broadcast bias
+    y = model(x)
+    self.assertEqual(y.shape, (3, 8, 6, 7, 5))
+    np.testing.assert_allclose(y, jnp.ones_like(y), rtol=1e-6)
+
+  def test_bias_addition_simple(self):
+    model = nnx.Einsum(
+      'bi,io->bo', (3, 5), (5,), rngs=nnx.Rngs(0)
+    )
+    x = jnp.ones((4, 3))
+    y_with_bias = model(x)
+
+    no_bias_result = jnp.einsum('bi,io->bo', x, model.kernel[...])
+    diff = y_with_bias - no_bias_result
+    expected_diff = jnp.broadcast_to(
+      model.bias[...], y_with_bias.shape
+    )
+    np.testing.assert_allclose(diff, expected_diff, rtol=1e-6)
+
+  def test_bias_broadcast_with_ellipsis(self):
+    model = nnx.Einsum(
+      'd...ab,bc...->ad...c',
+      (4, 5, 6, 7),
+      (5, 6, 7),
+      kernel_init=nnx.initializers.zeros,
+      bias_init=nnx.initializers.ones,
+      rngs=nnx.Rngs(0),
+    )
+    x = jnp.zeros((8, 6, 7, 3, 4))
+    # Verify broadcasted bias shape for ellipsis equation
+    broadcasted_shape = model._infer_broadcasted_bias_shape(
+      'd...ab,bc...->ad...c', x, model.kernel[...]
+    )
+    self.assertEqual(broadcasted_shape, [1, 1, 6, 7, 5])
+    y = model(x)
+    self.assertEqual(y.shape, (3, 8, 6, 7, 5))
+    np.testing.assert_allclose(y, jnp.ones_like(y), rtol=1e-6)
+
+  def test_spaces_in_einsum_str(self):
+    model_spaces = nnx.Einsum(
+      'b i, i o -> b o', (3, 5), bias_shape=None,
+      kernel_init=nnx.initializers.ones, rngs=nnx.Rngs(0),
+    )
+    model_clean = nnx.Einsum(
+      'bi,io->bo', (3, 5), bias_shape=None,
+      kernel_init=nnx.initializers.ones, rngs=nnx.Rngs(0),
+    )
+    x = jnp.ones((2, 3))
+    np.testing.assert_array_equal(
+      model_spaces(x), model_clean(x)
+    )
+
+  def test_einsum_str_missing_arrow_raises(self):
+    with self.assertRaisesRegex(
+      ValueError, 'must be explicit and include'
+    ):
+      nnx.Einsum('bi,io', (3, 5), rngs=nnx.Rngs(0))
+
+  def test_einsum_str_wrong_operand_count_raises(self):
+    with self.assertRaisesRegex(
+      ValueError, 'exactly two operands'
+    ):
+      nnx.Einsum('a,b,c->d', (3,), rngs=nnx.Rngs(0))
+
+  @parameterized.product(
+    precision=[Precision.DEFAULT, Precision.HIGH, Precision.HIGHEST],
+  )
+  def test_precision(self, precision):
+    received = []
+
+    def capturing_einsum(
+      *args, precision=None, out_sharding=None, **kwargs,
+    ):
+      received.append(precision)
+      return jnp.einsum(*args, precision=precision, **kwargs)
+
+    model = nnx.Einsum(
+      'bi,io->bo', (3, 5), bias_shape=None,
+      precision=precision,
+      einsum_op=capturing_einsum,
+      rngs=nnx.Rngs(0),
+    )
+    model(jnp.ones((2, 3)))
+    self.assertLen(received, 1)
+    self.assertEqual(received[0], precision)
+
+  @parameterized.product(
+    bias_shape=[(5,), None],
+  )
+  def test_gradient_flow(self, bias_shape):
+    model = nnx.Einsum(
+      'bi,io->bo', (3, 5), bias_shape,
+      kernel_init=nnx.initializers.ones,
+      bias_init=nnx.initializers.zeros,
+      rngs=nnx.Rngs(0),
+    )
+    x = jnp.ones((2, 3))
+
+    grads = jax.grad(lambda m: m(x).sum())(model)
+    # d(sum(x @ kernel + bias))/d(kernel[j,k]) = sum_i(x[i,j]) = 2.0
+    np.testing.assert_allclose(
+      grads.kernel[...], jnp.full((3, 5), 2.0), rtol=1e-6
+    )
+    if bias_shape is not None:
+      # d(sum(bias))/d(bias[k]) = batch_size = 2.0
+      np.testing.assert_allclose(
+        grads.bias[...], jnp.full(bias_shape, 2.0), rtol=1e-6
+      )
+    else:
+      self.assertIsNone(grads.bias)
+
+  def test_einsum_str_call_override(self):
+    received = []
+
+    def capturing_einsum(
+      expr, *args, out_sharding=None, **kwargs,
+    ):
+      received.append(expr)
+      return jnp.einsum(expr, *args, **kwargs)
+
+    # Constructor uses 'ab,bc->ac', call overrides with 'bi,io->bo'
+    model = nnx.Einsum(
+      'ab,bc->ac', (3, 5), bias_shape=None,
+      einsum_op=capturing_einsum, rngs=nnx.Rngs(0),
+    )
+    x = jnp.ones((2, 3))
+
+    # Without override: constructor string is used
+    model(x)
+    self.assertLen(received, 1)
+    self.assertEqual(received[0], 'ab,bc->ac')
+
+    # With override: call-time string must win
+    model(x, einsum_str='bi,io->bo')
+    self.assertLen(received, 2)
+    self.assertEqual(received[1], 'bi,io->bo')
+
+  def test_custom_initializers(self):
+    model = nnx.Einsum(
+      'bi,io->bo',
+      (3, 5),
+      (5,),
+      kernel_init=nnx.initializers.ones,
+      bias_init=nnx.initializers.zeros,
+      param_dtype=jnp.float32,
+      rngs=nnx.Rngs(0),
+    )
+    np.testing.assert_array_equal(
+      model.kernel[...], jnp.ones((3, 5))
+    )
+    np.testing.assert_array_equal(
+      model.bias[...], jnp.zeros((5,))
+    )
+    # Verify initializers affect forward pass
+    x = jnp.ones((2, 3))
+    y = model(x)
+    # kernel=ones, bias=zeros: y = x @ ones + 0 = [[3,3,3,3,3],...]
+    expected = jnp.full((2, 5), 3.0)
+    np.testing.assert_allclose(y, expected, rtol=1e-6)
+
+  def test_ellipsis_einsum_str(self):
+    model = nnx.Einsum(
+      'd...ab,bc...->ad...c',
+      (4, 5, 6, 7),
+      bias_shape=None,
+      rngs=nnx.Rngs(0),
+    )
+    x = jnp.ones((8, 6, 7, 3, 4))
+    y = model(x)
+    self.assertEqual(y.shape, (3, 8, 6, 7, 5))
+
+
 class TestLayersSameGraph(parameterized.TestCase):
 
   @parameterized.product(