@@ -126,16 +126,18 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., rotary_emb = Non
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)

-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
-
+        self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
+        self.to_v = nn.Linear(dim, inner_dim, bias = False)
+
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()

     def forward(self, x, pos = None):
         x = self.norm(x)
-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
+
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

         # Apply rotary embeddings if available
@@ -245,6 +247,23 @@ def __init__(
         self.to_latent = nn.Identity()
         self.mlp_head = nn.Linear(dim, num_classes)

+    def muon_parameters(self):
+        params = []
+
+        for m in self.modules():
+            if isinstance(m, Attention):
+                params.extend([
+                    m.to_v.weight,
+                    m.to_out[0].weight
+                ])
+            elif isinstance(m, FeedForward):
+                params.extend([
+                    m.net[1].weight,
+                    m.net[-2].weight
+                ])
+
+        return params
+
     def forward(
         self,
         x,
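Splitting `to_qkv` into separate `to_qk` and `to_v` projections lets the new `muon_parameters()` helper expose the value projection as its own 2D weight matrix, alongside the attention output projection and the two feedforward Linear weights, while `to_qk`, embeddings, norms, biases, and the classification head stay with a standard optimizer. Below is a rough usage sketch of that split, not part of the commit: the import path, the `ViT` constructor arguments, and the commented-out `Muon` constructor are assumptions; only `muon_parameters()` itself comes from this diff.

# Rough sketch (assumptions noted below, not from this commit): route the weights
# collected by muon_parameters() to Muon, and everything else to AdamW.
from torch.optim import AdamW

# hypothetical import path and hyperparameters -- adjust to the actual module
from vit_rotary import ViT

model = ViT(image_size = 224, patch_size = 16, num_classes = 1000,
            dim = 512, depth = 6, heads = 8, mlp_dim = 1024)

# weights returned by the new helper: to_v, to_out[0], and both feedforward Linear weights
muon_params = model.muon_parameters()
muon_ids = set(map(id, muon_params))

# everything else (patch embedding, to_qk, norms, biases, mlp_head) stays on AdamW
adamw_params = [p for p in model.parameters() if id(p) not in muon_ids]

# opt_muon = Muon(muon_params, lr = 2e-2, momentum = 0.95)   # hypothetical Muon API
opt_adamw = AdamW(adamw_params, lr = 3e-4, weight_decay = 0.1)

Note that `to_qk.weight` is not in the returned list, so in this sketch it falls into the AdamW group; the helper also only collects `.weight` matrices, so biases and 1D norm parameters are excluded by construction.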