Skip to content

Commit 36717ce

Browse files
mlstm start
1 parent a982ae1 commit 36717ce

File tree

4 files changed

+180
-3
lines changed

4 files changed

+180
-3
lines changed

noxton/nn/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
from .convolution import ConvNormActivation
44
from .embedding import EmbeddingBag, EmbeddingWithPadding
55
from .mamba import Mamba, MambaBlock, SelectiveStateSpaceModel
6-
from .normalization import BatchNorm, LayerNorm, LocalResponseNormalization
6+
from .normalization import (
7+
BatchNorm,
8+
LayerNorm,
9+
LocalResponseNormalization,
10+
ResidualLayerNorm,
11+
)
712
from .regularization import StochasticDepth
813
from .sequential import BatchedLinear
914
from .state_space import SelectiveStateSpace
@@ -15,6 +20,7 @@
1520
TransformerEncoderLayer,
1621
VisionTransformer,
1722
)
23+
from .xlstm import mLSTMCell
1824

1925
__all__ = [
2026
"AbstractNorm",
@@ -26,6 +32,7 @@
2632
"EmbeddingWithPadding",
2733
"BatchNorm",
2834
"LayerNorm",
35+
"ResidualLayerNorm",
2936
"LocalResponseNormalization",
3037
"StochasticDepth",
3138
"BatchedLinear",
@@ -39,4 +46,5 @@
3946
"Mamba",
4047
"SelectiveStateSpaceModel",
4148
"MambaBlock",
49+
"mLSTMCell",
4250
]

noxton/nn/normalization.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import jax
33
import jax.numpy as jnp
44
from beartype.typing import Any, Hashable, Sequence
5+
from equinox import Module, field
56
from equinox.nn import State
67
from jaxtyping import Array, Float, PRNGKeyArray
78

@@ -73,8 +74,8 @@ def __init__(
7374
eps: float = 1e-5,
7475
momentum: float = 0.1,
7576
affine: bool = True,
76-
inference: bool = False,
7777
dtype: Any | None = None,
78+
inference: bool = False,
7879
):
7980
if dtype is None:
8081
dtype = default_floating_dtype()
@@ -364,3 +365,100 @@ def __call__(self, x: Array, *_, key: PRNGKeyArray | None = None, **__) -> Array
364365

365366
out = out.astype(orig_dtype)
366367
return out
368+
369+
370+
class ResidualLayerNorm(Module):
    """Layer normalisation with a residual scale parameter.

    Normalises the input by subtracting the mean and dividing by the standard
    deviation computed over the entire array. The learnable affine scale
    parameter is formulated as a residual ``1 + weight``, where ``weight`` is
    initialised to zero, so the module starts out as plain normalisation.

    Unlike ``LayerNorm``, this module expects the input to exactly match the
    configured ``shape`` and does not automatically broadcast over leading
    batch dimensions; use ``jax.vmap`` for batched inputs.

    Computation is performed at a higher precision (at least ``float32``) and
    the result is cast back to the original dtype.

    Args:
        shape: The exact shape of the unbatched input array. Pass a single
            ``int`` for the common 1-D case.
        eps: Small constant added to the variance for numerical stability.
            Defaults to ``1e-5``.
        use_weight: If ``True``, learn a per-element residual scale parameter
            initialised to ``0``. Defaults to ``True``.
        use_bias: If ``True``, learn a per-element bias parameter initialised
            to ``0``. Defaults to ``False``.
        dtype: Floating-point dtype for the affine parameters. Defaults to
            ``None``.

    Raises:
        ValueError: If the input shape does not exactly match ``shape``.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> rln = ResidualLayerNorm(shape=64)
        >>> x = jnp.ones((10, 64))
        >>> jax.vmap(rln)(x).shape
        (10, 64)
    """

    shape: tuple[int, ...] = field(static=True)
    eps: float = field(static=True)
    use_weight: bool = field(static=True)
    use_bias: bool = field(static=True)
    weight: Float[Array, "*shape"] | None
    bias: Float[Array, "*shape"] | None

    def __init__(
        self,
        shape: int | Sequence[int],
        eps: float = 1e-5,
        use_weight: bool = True,
        use_bias: bool = False,
        # Annotated to match the convention used elsewhere in this file
        # (e.g. BatchNorm.__init__'s ``dtype: Any | None = None``).
        dtype: Any | None = None,
    ):
        if isinstance(shape, int):
            shape = (shape,)
        else:
            shape = tuple(shape)
        self.shape = shape
        self.eps = eps
        self.use_weight = use_weight
        self.use_bias = use_bias
        # Zero init: scale is applied as ``1 + weight`` below, so the module
        # is an identity-scaled normalisation at initialisation.
        self.weight = jnp.zeros(shape, dtype=dtype) if use_weight else None
        self.bias = jnp.zeros(shape, dtype=dtype) if use_bias else None

    def __call__(
        self,
        x: Float[Array, "*shape"],
        *,
        key: PRNGKeyArray | None = None,
    ) -> Array:
        if x.shape != self.shape:
            raise ValueError(
                f"Expected shape {self.shape}, got {x.shape}. You might need jax.vmap."
            )

        orig_dtype = x.dtype
        # Compute in at least float32 for numerical stability, then cast back.
        with jax.numpy_dtype_promotion("standard"):
            dtype = jnp.result_type(x.dtype, jnp.float32)

        x = x.astype(dtype)
        # Statistics over the *entire* array (no axis argument).
        mean = jnp.mean(x, keepdims=True)
        variance = jnp.var(x, keepdims=True)
        # Guard against tiny negative variance from floating-point error.
        variance = jnp.maximum(0.0, variance)
        inv = jax.lax.rsqrt(variance + self.eps)
        out = (x - mean) * inv

        if self.use_weight:
            assert self.weight is not None
            # Residual formulation: effective scale is ``1 + weight``.
            out = (1.0 + self.weight.astype(dtype)) * out
        if self.use_bias:
            assert self.bias is not None
            out = out + self.bias.astype(dtype)

        return out.astype(orig_dtype)

noxton/nn/xlstm.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,76 @@
11
import equinox as eqx
2+
import jax
3+
import jax.numpy as jnp
4+
from beartype.typing import Any
5+
from jaxtyping import Array, Float, PRNGKeyArray
6+
7+
from noxton.nn import ResidualLayerNorm
28

39

410
class mLSTMCell(eqx.Module):
    """Matrix-LSTM (mLSTM) cell gates for the xLSTM architecture.

    Holds the input- and forget-gate projections and the output head
    normalisation. NOTE(review): the mLSTM recurrence itself is not
    implemented yet — ``__call__`` currently only computes the per-head
    gate pre-activations (this is the "mlstm start" commit).

    Args:
        embedding_dim: Total embedding dimension; must be divisible by
            ``num_heads``.
        num_heads: Number of attention-style heads the embedding is split
            into.
        key: PRNG key used to initialise the gate parameters.
        dtype: Floating-point dtype for the parameters. Defaults to ``None``.
    """

    embedding_dim: int
    num_heads: int

    igate: eqx.nn.Linear
    fgate: eqx.nn.Linear

    outnorm: ResidualLayerNorm

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        key: PRNGKeyArray,
        dtype: Any | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        key, ikey, fkey = jax.random.split(key, 3)

        # Input gate: zero weights, bias linearly spaced over [3, 6] so the
        # gate starts strongly open.
        igate = eqx.nn.Linear(3 * embedding_dim, num_heads, key=ikey, dtype=dtype)
        igate = eqx.tree_at(
            lambda l: l.weight, igate, jnp.zeros_like(igate.weight, dtype=dtype)
        )
        self.igate = eqx.tree_at(
            lambda l: l.bias,
            igate,
            jnp.linspace(start=3.0, stop=6.0, num=len(igate.bias), dtype=dtype),
        )

        # Forget gate: zero weights, small Gaussian bias.
        fgate = eqx.nn.Linear(3 * embedding_dim, num_heads, key=fkey, dtype=dtype)
        fgate = eqx.tree_at(
            lambda l: l.weight, fgate, jnp.zeros_like(fgate.weight, dtype=dtype)
        )
        key, subkey = jax.random.split(key)
        # Cast the sampled bias to the existing bias dtype: previously the
        # sample stayed in default float32 even when ``dtype`` was set,
        # unlike every other parameter of this cell.
        fbias = jnp.sqrt(0.1) * jax.random.normal(key=subkey, shape=fgate.bias.shape)
        self.fgate = eqx.tree_at(
            lambda l: l.bias, fgate, fbias.astype(fgate.bias.dtype)
        )

        self.outnorm = ResidualLayerNorm(embedding_dim, use_bias=False, dtype=dtype)

    def __call__(
        self,
        q: Float[Array, "seq_len embed_dim"],
        k: Float[Array, "seq_len embed_dim"],
        v: Float[Array, "seq_len embed_dim"],
    ):
        seq_len, _ = q.shape
        # Both gates see the concatenated (q, k, v) of each position.
        if_gate_input = jnp.concatenate((q, k, v), axis=1)
        head_dim = self.embedding_dim // self.num_heads
        # (seq_len, embed_dim) -> (num_heads, seq_len, head_dim)
        q = jnp.reshape(q, shape=(seq_len, self.num_heads, head_dim)).transpose(1, 0, 2)
        k = jnp.reshape(k, shape=(seq_len, self.num_heads, head_dim)).transpose(1, 0, 2)
        v = jnp.reshape(v, shape=(seq_len, self.num_heads, head_dim)).transpose(1, 0, 2)

        # eqx.nn.Linear is unbatched (expects a single vector), so vmap over
        # the sequence axis; previously the 2-D array was passed directly,
        # which shape-mismatches in the underlying ``weight @ x``.
        igate_preact = jax.vmap(self.igate)(if_gate_input)  # (seq_len, num_heads)
        igate_preact = jnp.expand_dims(igate_preact.T, axis=-1)  # (heads, seq, 1)

        fgate_preact = jax.vmap(self.fgate)(if_gate_input)  # (seq_len, num_heads)
        fgate_preact = jnp.expand_dims(fgate_preact.T, axis=-1)  # (heads, seq, 1)

        # TODO: implement the mLSTM recurrence (stabilised gate matrix,
        # covariance update, outnorm) — nothing is returned yet.
674

775

876
class mLSTMLayer(eqx.Module):

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,6 @@ dev = [
2727
"torch>=2.10.0",
2828
"torchvision>=0.25.0",
2929
]
30+
31+
[tool.ruff.lint]
32+
ignore = ["E741", "F722"]

0 commit comments

Comments
 (0)