Skip to content

Commit 39fd9ac

Browse files
committed
for n-dimensional vit, have a method for fetching muon friendly parameters
1 parent 3becf08 commit 39fd9ac

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "vit-pytorch"
7-
version = "1.12.4"
7+
version = "1.12.5"
88
description = "Vision Transformer (ViT) - Pytorch"
99
readme = { file = "README.md", content-type = "text/markdown" }
1010
license = { file = "LICENSE" }

vit_pytorch/vit_nd_rotary.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,18 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., rotary_emb = Non
126126
self.attend = nn.Softmax(dim = -1)
127127
self.dropout = nn.Dropout(dropout)
128128

129-
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
130-
129+
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
130+
self.to_v = nn.Linear(dim, inner_dim, bias = False)
131+
131132
self.to_out = nn.Sequential(
132133
nn.Linear(inner_dim, dim),
133134
nn.Dropout(dropout)
134135
) if project_out else nn.Identity()
135136

136137
def forward(self, x, pos = None):
137138
x = self.norm(x)
138-
qkv = self.to_qkv(x).chunk(3, dim = -1)
139+
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
140+
139141
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
140142

141143
# Apply rotary embeddings if available
@@ -245,6 +247,23 @@ def __init__(
245247
self.to_latent = nn.Identity()
246248
self.mlp_head = nn.Linear(dim, num_classes)
247249

250+
def muon_parameters(self):
    """Collect the weight matrices meant to be handed to a Muon optimizer.

    Walks every submodule of this model and gathers:

    * for each ``Attention``: the value projection weight (``to_v``) and,
      when the block has one, the output projection weight (``to_out[0]``);
    * for each ``FeedForward``: the weights of ``net[1]`` and ``net[-2]``.

    The fused query/key projection (``to_qk``) is not included.

    Returns:
        A list of parameters in module traversal order.
    """
    params = []

    for m in self.modules():
        if isinstance(m, Attention):
            params.append(m.to_v.weight)

            # `to_out` is an nn.Identity() when the attention block is built
            # without an output projection; Identity cannot be indexed, so
            # only reach into it when it is the Sequential(Linear, Dropout).
            if isinstance(m.to_out, nn.Sequential):
                params.append(m.to_out[0].weight)

        elif isinstance(m, FeedForward):
            # NOTE(review): presumably net[1] and net[-2] are the two Linear
            # layers of the feedforward stack — confirm against FeedForward.
            params.extend([
                m.net[1].weight,
                m.net[-2].weight,
            ])

    return params
266+
248267
def forward(
249268
self,
250269
x,

0 commit comments

Comments
 (0)