tests/nnx/nn/linear_test.py: 275 additions, 0 deletions
@@ -240,6 +240,281 @@ def linen_loss_function(variables):
)


class TestEinsum(parameterized.TestCase):

  @parameterized.product(
    einsum_str_input_kernel_bias_output=[
      # matrix multiply: (batch, in) x (in, out) -> (batch, out)
      ('bi,io->bo', (4, 3), (3, 5), (5,), (4, 5)),
      # batched matmul with extra dims
      ('nta,hab->nthb', (16, 11, 2), (8, 2, 4), (8, 4), (16, 11, 8, 4)),
      # multi-head attention-like: (batch, seq, heads, dim)
      ('bshd,hdo->bsho', (2, 8, 4, 16), (4, 16, 32), (4, 32),
       (2, 8, 4, 32)),
      # contraction over multiple dims
      ('defab,bcef->adefc', (8, 6, 7, 3, 4), (4, 5, 6, 7), (6, 7, 5),
       (3, 8, 6, 7, 5)),
      # simple 2D without bias
      ('ij,jk->ik', (2, 3), (3, 4), None, (2, 4)),
    ],
  )
  def test_output_shape(self, einsum_str_input_kernel_bias_output):
    einsum_str, input_shape, kernel_shape, bias_shape, expected_shape = (
      einsum_str_input_kernel_bias_output
    )
    model = nnx.Einsum(
      einsum_str, kernel_shape, bias_shape, rngs=nnx.Rngs(0)
    )
    x = jnp.ones(input_shape)
    y = model(x)
    self.assertEqual(y.shape, expected_shape)
    self.assertEqual(model.kernel.shape, kernel_shape)
    if bias_shape is not None:
      self.assertIsNotNone(model.bias)
      self.assertEqual(model.bias.shape, bias_shape)
    else:
      self.assertIsNone(model.bias)

  @parameterized.product(
    dtype=[None, jnp.float32, jnp.float16],
    param_dtype=[jnp.float32, jnp.float16],
    preferred_element_type=[None, jnp.float32],
  )
  def test_dtypes(self, dtype, param_dtype, preferred_element_type):
    model = nnx.Einsum(
      'bi,io->bo',
      (3, 5),
      (5,),
      dtype=dtype,
      param_dtype=param_dtype,
      preferred_element_type=preferred_element_type,
      rngs=nnx.Rngs(0),
    )
    self.assertEqual(model.kernel.dtype, param_dtype)
    self.assertEqual(model.bias.dtype, param_dtype)

    x = jnp.ones((2, 3))
    y = model(x)
    if preferred_element_type is not None:
      self.assertEqual(y.dtype, preferred_element_type)
    elif dtype is not None:
      self.assertEqual(y.dtype, dtype)
    else:
      # dtype=None: output dtype inferred from input (float32) and params
      expected_dtype = jnp.result_type(jnp.float32, param_dtype)
      self.assertEqual(y.dtype, expected_dtype)

  @parameterized.product(
    dtype=[jnp.float32, jnp.float16],
    param_dtype=[jnp.float32, jnp.float16],
  )
  def test_no_bias(self, dtype, param_dtype):
    model = nnx.Einsum(
      'bi,io->bo', (3, 5), bias_shape=None,
      dtype=dtype, param_dtype=param_dtype,
      kernel_init=nnx.initializers.ones, rngs=nnx.Rngs(0),
    )
    x = jnp.ones((2, 3), dtype=dtype)
    y = model(x)

    # kernel=ones: y[b,o] = sum_i(x[b,i] * 1) = 3.0 for all b,o
    rtol = (
      1e-3
      if dtype == jnp.float16 or param_dtype == jnp.float16
      else 1e-6
    )
    np.testing.assert_allclose(
      y, jnp.full((2, 5), 3.0, dtype=y.dtype), rtol=rtol,
    )

  def test_bias_addition_multidim(self):
    model = nnx.Einsum(
      'defab,bcef->adefc',
      (4, 5, 6, 7),
      (6, 7, 5),
      kernel_init=nnx.initializers.zeros,
      bias_init=nnx.initializers.ones,
      rngs=nnx.Rngs(0),
    )
    x = jnp.zeros((8, 6, 7, 3, 4))
    # Verify _infer_broadcasted_bias_shape computes correct shape
    # output='adefc': bias dims (e,f,c) from rhs -> [1, 1, 6, 7, 5]
    broadcasted_shape = model._infer_broadcasted_bias_shape(
      'defab,bcef->adefc', x, model.kernel[...]
    )
    self.assertEqual(broadcasted_shape, [1, 1, 6, 7, 5])
    # kernel=0 means einsum result is 0, output should be broadcast bias
    y = model(x)
    self.assertEqual(y.shape, (3, 8, 6, 7, 5))
    np.testing.assert_allclose(y, jnp.ones_like(y), rtol=1e-6)

  def test_bias_addition_simple(self):
    model = nnx.Einsum(
      'bi,io->bo', (3, 5), (5,), rngs=nnx.Rngs(0)
    )
    x = jnp.ones((4, 3))
    y_with_bias = model(x)

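    # Subtracting the bias-free einsum result should leave exactly the bias.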
    no_bias_result = jnp.einsum('bi,io->bo', x, model.kernel[...])
    diff = y_with_bias - no_bias_result
    expected_diff = jnp.broadcast_to(
      model.bias[...], y_with_bias.shape
    )
    np.testing.assert_allclose(diff, expected_diff, rtol=1e-6)

  def test_bias_broadcast_with_ellipsis(self):
    model = nnx.Einsum(
      'd...ab,bc...->ad...c',
      (4, 5, 6, 7),
      (5, 6, 7),
      kernel_init=nnx.initializers.zeros,
      bias_init=nnx.initializers.ones,
      rngs=nnx.Rngs(0),
    )
    x = jnp.zeros((8, 6, 7, 3, 4))
    # Verify broadcasted bias shape for ellipsis equation
    broadcasted_shape = model._infer_broadcasted_bias_shape(
      'd...ab,bc...->ad...c', x, model.kernel[...]
    )
    self.assertEqual(broadcasted_shape, [1, 1, 6, 7, 5])
    y = model(x)
    self.assertEqual(y.shape, (3, 8, 6, 7, 5))
    np.testing.assert_allclose(y, jnp.ones_like(y), rtol=1e-6)

  def test_spaces_in_einsum_str(self):
    model_spaces = nnx.Einsum(
      'b i, i o -> b o', (3, 5), bias_shape=None,
      kernel_init=nnx.initializers.ones, rngs=nnx.Rngs(0),
    )
    model_clean = nnx.Einsum(
      'bi,io->bo', (3, 5), bias_shape=None,
      kernel_init=nnx.initializers.ones, rngs=nnx.Rngs(0),
    )
    x = jnp.ones((2, 3))
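    # Both kernels are all-ones, so outputs match exactly iff whitespace is ignored.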
    np.testing.assert_array_equal(
      model_spaces(x), model_clean(x)
    )

  def test_einsum_str_missing_arrow_raises(self):
    with self.assertRaisesRegex(
      ValueError, 'must be explicit and include'
    ):
      nnx.Einsum('bi,io', (3, 5), rngs=nnx.Rngs(0))

  def test_einsum_str_wrong_operand_count_raises(self):
    with self.assertRaisesRegex(
      ValueError, 'exactly two operands'
    ):
      nnx.Einsum('a,b,c->d', (3,), rngs=nnx.Rngs(0))

  @parameterized.product(
    precision=[Precision.DEFAULT, Precision.HIGH, Precision.HIGHEST],
  )
  def test_precision(self, precision):
    received = []

    def capturing_einsum(
      *args, precision=None, out_sharding=None, **kwargs,
    ):
      received.append(precision)
      return jnp.einsum(*args, precision=precision, **kwargs)

    model = nnx.Einsum(
      'bi,io->bo', (3, 5), bias_shape=None,
      precision=precision,
      einsum_op=capturing_einsum,
      rngs=nnx.Rngs(0),
    )
    model(jnp.ones((2, 3)))
    self.assertLen(received, 1)
    self.assertEqual(received[0], precision)

  @parameterized.product(
    bias_shape=[(5,), None],
  )
  def test_gradient_flow(self, bias_shape):
    model = nnx.Einsum(
      'bi,io->bo', (3, 5), bias_shape,
      kernel_init=nnx.initializers.ones,
      bias_init=nnx.initializers.zeros,
      rngs=nnx.Rngs(0),
    )
    x = jnp.ones((2, 3))

    grads = jax.grad(lambda m: m(x).sum())(model)
    # d(sum(x @ kernel + bias))/d(kernel[j,k]) = sum_b(x[b,j]) = 2.0
    np.testing.assert_allclose(
      grads.kernel[...], jnp.full((3, 5), 2.0), rtol=1e-6
    )
    if bias_shape is not None:
      # d(sum(y))/d(bias[k]) = batch_size = 2.0
      np.testing.assert_allclose(
        grads.bias[...], jnp.full(bias_shape, 2.0), rtol=1e-6
      )
    else:
      self.assertIsNone(grads.bias)

  def test_einsum_str_call_override(self):
    received = []

    def capturing_einsum(
      expr, *args, out_sharding=None, **kwargs,
    ):
      received.append(expr)
      return jnp.einsum(expr, *args, **kwargs)

    # Constructor uses 'ab,bc->ac', call overrides with 'bi,io->bo'
    model = nnx.Einsum(
      'ab,bc->ac', (3, 5), bias_shape=None,
      einsum_op=capturing_einsum, rngs=nnx.Rngs(0),
    )
    x = jnp.ones((2, 3))

    # Without override: constructor string is used
    model(x)
    self.assertLen(received, 1)
    self.assertEqual(received[0], 'ab,bc->ac')

    # With override: call-time string must win
    model(x, einsum_str='bi,io->bo')
    self.assertLen(received, 2)
    self.assertEqual(received[1], 'bi,io->bo')

  def test_custom_initializers(self):
    model = nnx.Einsum(
      'bi,io->bo',
      (3, 5),
      (5,),
      kernel_init=nnx.initializers.ones,
      bias_init=nnx.initializers.zeros,
      param_dtype=jnp.float32,
      rngs=nnx.Rngs(0),
    )
    np.testing.assert_array_equal(
      model.kernel[...], jnp.ones((3, 5))
    )
    np.testing.assert_array_equal(
      model.bias[...], jnp.zeros((5,))
    )
    # Verify initializers affect forward pass
    x = jnp.ones((2, 3))
    y = model(x)
    # kernel=ones, bias=zeros: y = x @ ones + 0 = [[3,3,3,3,3],...]
    expected = jnp.full((2, 5), 3.0)
    np.testing.assert_allclose(y, expected, rtol=1e-6)

  def test_ellipsis_einsum_str(self):
    model = nnx.Einsum(
      'd...ab,bc...->ad...c',
      (4, 5, 6, 7),
      bias_shape=None,
      rngs=nnx.Rngs(0),
    )
    x = jnp.ones((8, 6, 7, 3, 4))
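    # '...' matches (6, 7); 'ad...c' -> shape (3, 8, 6, 7, 5).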
    y = model(x)
    self.assertEqual(y.shape, (3, 8, 6, 7, 5))


class TestLayersSameGraph(parameterized.TestCase):

  @parameterized.product(