Skip to content

Commit 91b4a21

Browse files
author
Guillaume SANCHEZ
committed
vit
1 parent f05f98d commit 91b4a21

File tree

3 files changed

+116
-1
lines changed

3 files changed

+116
-1
lines changed

torchelie/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@
1717
from .alexnet import *
1818
from .mlpmixer import *
1919
from .convnext import *
20-
#from .poolformer import *
20+
from .vit import ViTTrunk
21+
# from .poolformer import *

torchelie/models/vit.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
import torch.nn as nn
3+
from ..nn.transformer import ViTBlock
4+
5+
6+
class ViTTrunk(nn.Module):
    """
    Vision Transformer (ViT) trunk.

    Flattens a grid of patch embeddings into a token sequence, adds a
    learned positional encoding, prepends learnable "register" tokens,
    runs the sequence through a stack of ViTBlock layers, then drops the
    registers and restores the spatial layout.

    Args:
        seq_len (int): length of the flattened patch sequence; must equal
            (H/P) * (W/P) of the inputs given to ``forward``.
        d_model (int): embedding dimension; must equal the channel dim C
            of the inputs given to ``forward``.
        num_layers (int): number of ViTBlock layers.
        num_heads (int): number of attention heads per block.
        num_registers (int, optional): number of learnable register tokens
            prepended to the sequence. Default: 10.
    """

    def __init__(self, seq_len, d_model, num_layers, num_heads, num_registers=10):
        super().__init__()
        self.trunk = nn.ModuleList(
            [ViTBlock(d_model, num_heads) for _ in range(num_layers)]
        )
        # Learned additive positional encoding, one vector per patch position.
        self.pos_enc = nn.Parameter(torch.zeros(seq_len, d_model))
        # Registers initialized with std ~ 1/sqrt(d_model) so their scale
        # matches typical token embeddings.
        self.registers = nn.Parameter(
            torch.randn(num_registers, d_model) / (d_model**0.5)
        )

    def forward(self, x):
        """
        Forward pass for the ViTTrunk.

        Args:
            x (Tensor): patch embeddings of shape [B, C, H/P, W/P], with
                C == d_model and (H/P) * (W/P) == seq_len.

        Returns:
            Tensor: output tensor of the same shape, [B, C, H/P, W/P].
        """
        B, C, Hp, Wp = x.shape
        # [B, C, Hp, Wp] -> [B, L, C] with L = Hp * Wp
        x = x.view(B, C, Hp * Wp).permute(0, 2, 1)
        x = x + self.pos_enc
        # Prepend the register tokens, broadcast over the batch.
        x = torch.cat([self.registers.unsqueeze(0).expand(B, -1, -1), x], dim=1)
        for block in self.trunk:
            # Fix: ViTBlock.forward requires a conditioning argument `z`;
            # the trunk is unconditional, so pass None explicitly
            # (`block(x)` alone raised TypeError).
            x = block(x, None)
        # Drop the register tokens, keep only the patch tokens.
        x = x[:, len(self.registers):, :]
        # [B, L, C] -> [B, C, Hp, Wp]; reshape (not view) because permute
        # makes the tensor non-contiguous.
        x = x.permute(0, 2, 1).reshape(B, C, Hp, Wp)
        return x

torchelie/nn/transformer.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import torch.nn as nn
44
import torchelie.utils as tu
55
from .conv import Conv1x1
6+
from ..nn.condseq import CondSeq
7+
from ..nn.llm import SelfAttention
68
from .functional.transformer import local_attention_2d
79

810
from typing import Optional
@@ -75,3 +77,57 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7577

7678
x = self.out(x)
7779
return x
80+
81+
82+
class ViTBlock(nn.Module):
    """
    Vision Transformer (ViT) block: pre-norm self-attention followed by a
    pre-norm feed-forward MLP, each merged back into the residual stream
    through a learned per-channel tanh gate.

    Args:
        d_model (int): dimension of the model; must be divisible by
            ``num_heads``.
        num_heads (int): number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.sa = CondSeq(
            nn.RMSNorm(d_model),
            SelfAttention(
                d_model,
                num_heads,
                head_size=d_model // num_heads,
                # Bidirectional attention: vision tokens, no causal masking.
                causal=False,
                rotary=True,
            ),
        )
        self.mlp = CondSeq(
            nn.RMSNorm(d_model),
            tu.kaiming(nn.Linear(d_model, 4 * d_model)),
            nn.GELU(),
            tu.kaiming(nn.Linear(4 * d_model, d_model)),
        )
        # Learned gates modulating each residual branch.
        self.g1 = tu.kaiming(nn.Linear(d_model, d_model))
        self.g2 = tu.kaiming(nn.Linear(d_model, d_model))

    def forward(self, x, z=None):
        """
        Forward pass for the ViTBlock.

        Args:
            x (Tensor): input tensor of shape [B, L, d_model].
            z (Any, optional): conditioning input forwarded to the CondSeq
                modules. Default: None. (Fix: the docstring already
                described ``z`` as optional, but it previously had no
                default, which broke unconditional callers such as
                ``ViTTrunk`` invoking ``block(x)``.)

        Returns:
            Tensor: output tensor of shape [B, L, d_model].
        """
        x = self.sa(x, z) * torch.tanh(self.g1(x)) + x
        x = self.mlp(x, z) * torch.tanh(self.g2(x)) + x
        return x

0 commit comments

Comments
 (0)