77from noxton .nn import ResidualLayerNorm
88
99
def parallel_stabilized_simple(
    queries: "Float[Array, 'num_heads seq_len head_dim']",
    keys: "Float[Array, 'num_heads seq_len head_dim']",
    values: "Float[Array, 'num_heads seq_len head_dim']",
    igate_preact: "Float[Array, 'num_heads seq_len 1']",
    fgate_preact: "Float[Array, 'num_heads seq_len 1']",
    lower_triangular_matrix: "Float[Array, 'seq_len seq_len'] | None" = None,
    stabilize_rowwise: bool = True,
    eps: float = 1e-6,
    **kwargs,
) -> "Array":
    """Parallel (quadratic-time) mLSTM forward pass with log-space stabilization.

    Builds the gate-decay matrix D in log space, stabilizes it by subtracting
    its maximum before exponentiation, and combines it with the scaled
    query-key scores to mix the values causally.

    Args:
        queries: per-head query projections, shape (NH, S, DH).
        keys: per-head key projections, shape (NH, S, DH).
        values: per-head value projections, shape (NH, S, DH).
        igate_preact: input-gate pre-activations, shape (NH, S, 1).
        fgate_preact: forget-gate pre-activations, shape (NH, S, 1).
        lower_triangular_matrix: optional boolean causal mask of shape (T, T).
            If it is missing or T < S a fresh (S, S) mask is built; if T > S
            (e.g. a mask precomputed for max_seq_len) it is cropped to (S, S)
            so the masking broadcast below stays shape-correct.
        stabilize_rowwise: stabilize per row of D (True) or with a single
            per-head global maximum (False).
        eps: guard added to the normalizer to avoid division by zero.
        **kwargs: ignored; accepted for interface compatibility.

    Returns:
        Hidden pre-output states, shape (NH, S, DH).
    """
    NH, S, DH = queries.shape

    # Forget gates in log space: log(sigmoid(f)) is numerically safe.
    log_fgates = jax.nn.log_sigmoid(fgate_preact)

    # Build the causal mask when absent or too small; crop an oversized one
    # (the original used it as-is, which breaks the jnp.where broadcast
    # whenever the precomputed mask is larger than the current seq_len).
    if lower_triangular_matrix is None or lower_triangular_matrix.shape[0] < S:
        ltr = jnp.tril(jnp.ones(shape=(S, S), dtype=jnp.bool_))
    else:
        ltr = lower_triangular_matrix[:S, :S]

    # Cumulative log forget gates with a zero prefix, so position i holds
    # sum(log f_1 .. f_i); shape (NH, S+1, 1).
    log_fgates_cumsum = jnp.concatenate(
        (jnp.zeros((NH, 1, 1)), jnp.cumsum(log_fgates, axis=1)), axis=1
    )
    rep_log_fgates_cumsum = jnp.tile(log_fgates_cumsum, (1, 1, S + 1))

    # Entry (i, j) = sum(log f_{j+1} .. f_i): accumulated decay from j to i.
    _log_fg_matrix = rep_log_fgates_cumsum - rep_log_fgates_cumsum.transpose(0, 2, 1)
    # Keep only the causal (lower-triangular) part; future positions get
    # -inf so they vanish under exp().
    log_fg_matrix = jnp.where(ltr, _log_fg_matrix[:, 1:, 1:], -float("inf"))

    # D[i, j] = decay(i, j) + input-gate pre-activation at source step j.
    log_D_matrix = log_fg_matrix + igate_preact.transpose(0, 2, 1)

    # Stabilize before exponentiation by subtracting the maximum.
    if stabilize_rowwise:
        max_log_D = jnp.max(log_D_matrix, axis=-1, keepdims=True)
    else:
        max_log_D = jnp.expand_dims(
            jnp.max(log_D_matrix.reshape(NH, -1), axis=-1, keepdims=True), axis=-1
        )
    D_matrix = jnp.exp(log_D_matrix - max_log_D)

    keys_scaled = keys / jnp.sqrt(DH)

    # Gated attention-like combination, normalized row-wise; the max with
    # exp(-max_log_D) keeps the normalizer bounded away from zero.
    qk_matrix = queries @ keys_scaled.transpose(0, 2, 1)
    C_matrix = qk_matrix * D_matrix
    normalizer = jnp.maximum(
        jnp.abs(C_matrix.sum(axis=-1, keepdims=True)), jnp.exp(-max_log_D)
    )
    C_matrix_normalized = C_matrix / (normalizer + eps)
    return C_matrix_normalized @ values
61+
62+
1063class mLSTMCell (eqx .Module ):
64+ max_seq_len : int
1165 embedding_dim : int
1266 num_heads : int
1367
@@ -20,11 +74,13 @@ def __init__(
2074 self ,
2175 embedding_dim : int ,
2276 num_heads : int ,
77+ max_seq_len : int ,
2378 key : PRNGKeyArray ,
2479 dtype : Any | None = None ,
2580 ) -> None :
2681 self .embedding_dim = embedding_dim
2782 self .num_heads = num_heads
83+ self .max_seq_len = max_seq_len
2884 key , ikey , fkey = jax .random .split (key , 3 )
2985
3086 igate = eqx .nn .Linear (3 * embedding_dim , num_heads , key = ikey , dtype = dtype )
@@ -63,14 +119,22 @@ def __call__(
63119 k = jnp .reshape (k , shape = (seq_len , self .num_heads , head_dim )).transpose (1 , 0 , 2 )
64120 v = jnp .reshape (v , shape = (seq_len , self .num_heads , head_dim )).transpose (1 , 0 , 2 )
65121
66- igate_preact = self .igate (if_gate_input )
122+ igate_preact = eqx . filter_vmap ( self .igate ) (if_gate_input )
67123 igate_preact = jnp .expand_dims (igate_preact .T , axis = - 1 )
68124
69- fgate_preact = self .fgate (if_gate_input )
125+ fgate_preact = eqx . filter_vmap ( self .fgate ) (if_gate_input )
70126 fgate_preact = jnp .expand_dims (fgate_preact .T , axis = - 1 )
71127
72- print (f"{ igate_preact .shape = } " )
73- print (f"{ fgate_preact .shape = } " )
128+ ltr = jnp .tril (
129+ jnp .ones (shape = (self .max_seq_len , self .max_seq_len ), dtype = jnp .bool )
130+ )
131+
132+ h_state = parallel_stabilized_simple (
133+ q , k , v , igate_preact , fgate_preact , lower_triangular_matrix = ltr
134+ )
135+ h_state = h_state .transpose (1 , 0 , 2 ).reshape (seq_len , - 1 )
136+ h_state_norm = eqx .filter_vmap (self .outnorm )(h_state )
137+ return h_state_norm
74138
75139
76140class mLSTMLayer (eqx .Module ):
0 commit comments