Skip to content

Commit 351db2e

Browse files
mlstm call
1 parent 36717ce commit 351db2e

File tree

4 files changed

+1478
-87
lines changed

4 files changed

+1478
-87
lines changed

main.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1-
import equinox as eqx
21
import jax
2+
import jax.numpy as jnp
33

4+
from noxton.nn import mLSTMCell
45

5-
class Model(eqx.Module):
6-
inference: bool = eqx.field(static=True)
7-
lin: eqx.nn.Linear
6+
max_seq_len = 4
7+
embed_dim = 16
8+
num_heads = 8
9+
seq_len = 4
810

9-
def __init__(self):
10-
self.inference = False
11-
self.lin = eqx.nn.Linear(10, 10, key=jax.random.key(2))
1211

12+
cell = mLSTMCell(embed_dim, num_heads, key=jax.random.key(22), max_seq_len=4)
1313

14-
model = Model()
15-
print(model)
14+
q, k, v = (
15+
jnp.ones(shape=(seq_len, embed_dim)),
16+
jnp.ones(shape=(seq_len, embed_dim)),
17+
jnp.ones(shape=(seq_len, embed_dim)),
18+
)
1619

17-
model = eqx.nn.inference_mode(model)
18-
print(model)
20+
21+
cell(q, k, v)

noxton/nn/xlstm.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,61 @@
77
from noxton.nn import ResidualLayerNorm
88

99

10+
def parallel_stabilized_simple(
11+
queries: Float[Array, "num_heads seq_len head_dim"],
12+
keys: Float[Array, "num_heads seq_len head_dim"],
13+
values: Float[Array, "num_heads seq_len head_dim"],
14+
igate_preact: Float[Array, "num_heads seq_len"],
15+
fgate_preact: Float[Array, "num_heads seq_len"],
16+
lower_triangular_matrix: Float[Array, "seq_len seq_len"] | None = None,
17+
stabilize_rowwise: bool = True,
18+
eps: float = 1e-6,
19+
**kwargs,
20+
) -> Array:
21+
NH, S, DH = queries.shape
22+
23+
log_fgates = jax.nn.log_sigmoid(fgate_preact)
24+
if lower_triangular_matrix is None or lower_triangular_matrix.shape[0] < S:
25+
lower_triangular_matrix = jnp.tril(jnp.ones(shape=(S, S), dtype=jnp.bool))
26+
27+
assert lower_triangular_matrix is not None
28+
29+
log_fgates_cumsum = jnp.concatenate(
30+
(jnp.zeros((NH, 1, 1)), jnp.cumsum(log_fgates, axis=1)), axis=1
31+
)
32+
rep_log_fgates_cumsum = jnp.tile(log_fgates_cumsum, (1, 1, S + 1))
33+
34+
_log_fg_matrix = rep_log_fgates_cumsum - rep_log_fgates_cumsum.transpose(0, 2, 1)
35+
log_fg_matrix = jnp.where(
36+
lower_triangular_matrix, _log_fg_matrix[:, 1:, 1:], -float("inf")
37+
)
38+
log_D_matrix = log_fg_matrix + igate_preact.transpose(0, 2, 1)
39+
# D matrix stabilization
40+
if stabilize_rowwise:
41+
max_log_D = jnp.max(log_D_matrix, axis=-1, keepdims=True)
42+
else:
43+
max_log_D = jnp.expand_dims(
44+
jnp.max(log_D_matrix.reshape(NH, -1), axis=-1, keepdims=True), axis=-1
45+
)
46+
47+
log_D_matrix_stabilized = log_D_matrix - max_log_D
48+
D_matrix = jnp.exp(log_D_matrix_stabilized)
49+
50+
keys_scaled = keys / jnp.sqrt(DH)
51+
52+
qk_matrix = queries @ keys_scaled.transpose(0, 2, 1)
53+
C_matrix = qk_matrix * D_matrix
54+
normalizer = jnp.maximum(
55+
jnp.abs(C_matrix.sum(axis=-1, keepdims=True)), jnp.exp(-max_log_D)
56+
)
57+
C_matrix_normalized = C_matrix / (normalizer + eps)
58+
h_tilde_state = C_matrix_normalized @ values
59+
60+
return h_tilde_state
61+
62+
1063
class mLSTMCell(eqx.Module):
64+
max_seq_len: int
1165
embedding_dim: int
1266
num_heads: int
1367

@@ -20,11 +74,13 @@ def __init__(
2074
self,
2175
embedding_dim: int,
2276
num_heads: int,
77+
max_seq_len: int,
2378
key: PRNGKeyArray,
2479
dtype: Any | None = None,
2580
) -> None:
2681
self.embedding_dim = embedding_dim
2782
self.num_heads = num_heads
83+
self.max_seq_len = max_seq_len
2884
key, ikey, fkey = jax.random.split(key, 3)
2985

3086
igate = eqx.nn.Linear(3 * embedding_dim, num_heads, key=ikey, dtype=dtype)
@@ -63,14 +119,22 @@ def __call__(
63119
k = jnp.reshape(k, shape=(seq_len, self.num_heads, head_dim)).transpose(1, 0, 2)
64120
v = jnp.reshape(v, shape=(seq_len, self.num_heads, head_dim)).transpose(1, 0, 2)
65121

66-
igate_preact = self.igate(if_gate_input)
122+
igate_preact = eqx.filter_vmap(self.igate)(if_gate_input)
67123
igate_preact = jnp.expand_dims(igate_preact.T, axis=-1)
68124

69-
fgate_preact = self.fgate(if_gate_input)
125+
fgate_preact = eqx.filter_vmap(self.fgate)(if_gate_input)
70126
fgate_preact = jnp.expand_dims(fgate_preact.T, axis=-1)
71127

72-
print(f"{igate_preact.shape=}")
73-
print(f"{fgate_preact.shape=}")
128+
ltr = jnp.tril(
129+
jnp.ones(shape=(self.max_seq_len, self.max_seq_len), dtype=jnp.bool)
130+
)
131+
132+
h_state = parallel_stabilized_simple(
133+
q, k, v, igate_preact, fgate_preact, lower_triangular_matrix=ltr
134+
)
135+
h_state = h_state.transpose(1, 0, 2).reshape(seq_len, -1)
136+
h_state_norm = eqx.filter_vmap(self.outnorm)(h_state)
137+
return h_state_norm
74138

75139

76140
class mLSTMLayer(eqx.Module):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies = [
88
"equinox>=0.13.4",
99
"ftfy>=6.3.1",
1010
"jax>=0.9.0.1",
11+
"numpix>=0.9.10",
1112
"statedict2pytree>=2.0.2",
1213
]
1314

0 commit comments

Comments
 (0)