# vit_class.py
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torchvision import transforms
# 1. Create a class which subclasses nn.Module
class PatchEmbedding(nn.Module):
"""Turns a 2D input image into a 1D sequence learnable embedding vector.
Args:
in_channels (int): Number of color channels for the input images. Defaults to 3.
patch_size (int): Size of patches to convert input image into. Defaults to 16.
embedding_dim (int): Size of embedding to turn image into. Defaults to 768.
"""
# 2. Initialize the class with appropriate variables
def __init__(self,
in_channels:int=3,
patch_size:int=16,
embedding_dim:int=768):
super().__init__()
self.patch_size = patch_size
# 3. Create a layer to turn an image into patches
self.patcher = nn.Conv2d(in_channels=in_channels,
out_channels=embedding_dim,
kernel_size=patch_size,
stride=patch_size,
padding=0)
# 4. Create a layer to flatten the patch feature maps into a single dimension
self.flatten = nn.Flatten(start_dim=2, # only flatten the feature map dimensions into a single vector
end_dim=3)
# 5. Define the forward method
def forward(self, x):
# Create assertion to check that inputs are the correct shape
image_resolution = x.shape[-1]
        assert image_resolution % self.patch_size == 0, f"Input image size must be divisible by patch size, image shape: {image_resolution}, patch size: {self.patch_size}"
# Perform the forward pass
x_patched = self.patcher(x)
x_flattened = self.flatten(x_patched)
# 6. Make sure the output shape has the right order
return x_flattened.permute(0, 2, 1) # adjust so the embedding is on the final dimension [batch_size, P^2•C, N] -> [batch_size, N, P^2•C]
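# A minimal, commented-out sanity check for PatchEmbedding (not part of the original file;
# assumes a random 224x224 RGB image, uncomment to run):
# patch_embedding_layer = PatchEmbedding(in_channels=3, patch_size=16, embedding_dim=768)
# dummy_image = torch.randn(1, 3, 224, 224)       # [batch_size, color_channels, height, width]
# patches = patch_embedding_layer(dummy_image)
# print(patches.shape)                            # torch.Size([1, 196, 768]) -> [batch_size, num_patches, embedding_dim]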
# 1. Create a class that inherits from nn.Module
class MultiheadSelfAttentionBlock(nn.Module):
"""Creates a multi-head self-attention block ("MSA block" for short).
"""
# 2. Initialize the class with hyperparameters from Table 1
def __init__(self,
embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
num_heads:int=12, # Heads from Table 1 for ViT-Base
attn_dropout:float=0): # doesn't look like the paper uses any dropout in MSABlocks
super().__init__()
# 3. Create the Norm layer (LN)
self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
# 4. Create the Multi-Head Attention (MSA) layer
self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
num_heads=num_heads,
dropout=attn_dropout,
batch_first=True) # does our batch dimension come first?
    # 5. Create a forward() method to pass the data through the layers
def forward(self, x):
x = self.layer_norm(x)
attn_output, _ = self.multihead_attn(query=x, # query embeddings
key=x, # key embeddings
value=x, # value embeddings
need_weights=False) # do we need the weights or just the layer outputs?
return attn_output
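# A minimal, commented-out sanity check for MultiheadSelfAttentionBlock (not part of the original file;
# assumes a sequence of 196 patch embeddings plus 1 class token):
# msa_block = MultiheadSelfAttentionBlock(embedding_dim=768, num_heads=12)
# dummy_sequence = torch.randn(1, 197, 768)       # [batch_size, num_patches+1, embedding_dim]
# print(msa_block(dummy_sequence).shape)          # torch.Size([1, 197, 768]) -- shape is preserved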
# 1. Create a class that inherits from nn.Module
class MLPBlock(nn.Module):
"""Creates a layer normalized multilayer perceptron block ("MLP block" for short)."""
# 2. Initialize the class with hyperparameters from Table 1 and Table 3
def __init__(self,
embedding_dim:int=768, # Hidden Size D from Table 1 for ViT-Base
mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
dropout:float=0.1): # Dropout from Table 3 for ViT-Base
super().__init__()
# 3. Create the Norm layer (LN)
self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
# 4. Create the Multilayer perceptron (MLP) layer(s)
self.mlp = nn.Sequential(
nn.Linear(in_features=embedding_dim,
out_features=mlp_size),
nn.GELU(), # "The MLP contains two layers with a GELU non-linearity (section 3.1)."
nn.Dropout(p=dropout),
nn.Linear(in_features=mlp_size, # needs to take same in_features as out_features of layer above
out_features=embedding_dim), # take back to embedding_dim
nn.Dropout(p=dropout) # "Dropout, when used, is applied after every dense layer.."
)
    # 5. Create a forward() method to pass the data through the layers
def forward(self, x):
x = self.layer_norm(x)
x = self.mlp(x)
return x
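# A minimal, commented-out sanity check for MLPBlock (not part of the original file;
# uses the ViT-Base defaults from Table 1):
# mlp_block = MLPBlock(embedding_dim=768, mlp_size=3072, dropout=0.1)
# dummy_sequence = torch.randn(1, 197, 768)       # [batch_size, num_patches+1, embedding_dim]
# print(mlp_block(dummy_sequence).shape)          # torch.Size([1, 197, 768]) -- shape is preserved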
# 1. Create a class that inherits from nn.Module
class TransformerEncoderBlock(nn.Module):
"""Creates a Transformer Encoder block."""
# 2. Initialize the class with hyperparameters from Table 1 and Table 3
def __init__(self,
embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
num_heads:int=12, # Heads from Table 1 for ViT-Base
mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
mlp_dropout:float=0.1, # Amount of dropout for dense layers from Table 3 for ViT-Base
attn_dropout:float=0): # Amount of dropout for attention layers
super().__init__()
# 3. Create MSA block (equation 2)
self.msa_block = MultiheadSelfAttentionBlock(embedding_dim=embedding_dim,
num_heads=num_heads,
attn_dropout=attn_dropout)
# 4. Create MLP block (equation 3)
self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
mlp_size=mlp_size,
dropout=mlp_dropout)
# 5. Create a forward() method
def forward(self, x):
# 6. Create residual connection for MSA block (add the input to the output)
x = self.msa_block(x) + x
# 7. Create residual connection for MLP block (add the input to the output)
x = self.mlp_block(x) + x
return x
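# A minimal, commented-out sanity check for TransformerEncoderBlock (not part of the original file;
# uses the ViT-Base defaults, so input and output shapes match):
# transformer_block = TransformerEncoderBlock(embedding_dim=768, num_heads=12, mlp_size=3072)
# dummy_sequence = torch.randn(1, 197, 768)       # [batch_size, num_patches+1, embedding_dim]
# print(transformer_block(dummy_sequence).shape)  # torch.Size([1, 197, 768]) -- shape is preserved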
# Example: create an instance of ViT with adjusted (smaller) hyperparameters
# (assumes `class_names` has been defined elsewhere, e.g. from a dataset)
# vit = ViT(
# img_size=224,
# patch_size=16,
# num_classes=len(class_names),
# num_transformer_layers=6, # Reduced from 12
# embedding_dim=512, # Reduced from 768
# mlp_size=2048, # Reduced from 3072
# num_heads=8, # Reduced from 12
# attn_dropout=0.1,
# mlp_dropout=0.1,
# embedding_dropout=0.1,
# )
# For reference, the original ViT-Base hyperparameters (Table 1 and Table 3 of the ViT paper):
# img_size:int=224, # Training resolution from Table 3 in ViT paper
# in_channels:int=3, # Number of channels in input image
# patch_size:int=16, # Patch size
# num_transformer_layers:int=12, # Layers from Table 1 for ViT-Base
# embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
# mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
# num_heads:int=12, # Heads from Table 1 for ViT-Base
# attn_dropout:float=0, # Dropout for attention projection
# mlp_dropout:float=0.1, # Dropout for dense/MLP layers
# embedding_dropout:float=0.1, # Dropout for patch and position embeddings
# num_classes:int=1000): # Default for ImageNet but can customize this
# 1. Create a ViT class that inherits from nn.Module
class ViT(nn.Module):
"""Creates a Vision Transformer architecture with ViT-Base hyperparameters by default."""
# 2. Initialize the class with hyperparameters from Table 1 and Table 3
def __init__(self,
img_size:int=224, # Training resolution from Table 3 in ViT paper
in_channels:int=3, # Number of channels in input image
patch_size:int=16, # Patch size
                 num_transformer_layers:int=6, # Layers from Table 1 for ViT-Base (reduced from 12 here)
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 mlp_size:int=2048, # MLP size from Table 1 for ViT-Base (reduced from 3072 here)
                 num_heads:int=8, # Heads from Table 1 for ViT-Base (reduced from 12 here)
attn_dropout:float=0, # Dropout for attention projection
mlp_dropout:float=0.1, # Dropout for dense/MLP layers
embedding_dropout:float=0.1, # Dropout for patch and position embeddings
num_classes:int=1000): # Default for ImageNet but can customize this
super().__init__() # don't forget the super().__init__()!
        # 3. Make sure the image size is divisible by the patch size
assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."
# 4. Calculate number of patches (height * width/patch^2)
self.num_patches = (img_size * img_size) // patch_size**2
# 5. Create learnable class embedding (needs to go at front of sequence of patch embeddings)
self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
requires_grad=True)
# 6. Create learnable position embedding
self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches+1, embedding_dim),
requires_grad=True)
# 7. Create embedding dropout value
self.embedding_dropout = nn.Dropout(p=embedding_dropout)
# 8. Create patch embedding layer
self.patch_embedding = PatchEmbedding(in_channels=in_channels,
patch_size=patch_size,
embedding_dim=embedding_dim)
# 9. Create Transformer Encoder blocks (we can stack Transformer Encoder blocks using nn.Sequential())
        # Note: The "*" unpacks the list so each TransformerEncoderBlock is passed to nn.Sequential() as a separate argument
self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_size=mlp_size,
mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)])
# 10. Create classifier head
self.classifier = nn.Sequential(
nn.LayerNorm(normalized_shape=embedding_dim),
nn.Linear(in_features=embedding_dim,
out_features=num_classes)
)
# 11. Create a forward() method
def forward(self, x):
# 12. Get batch size
batch_size = x.shape[0]
# 13. Create class token embedding and expand it to match the batch size (equation 1)
class_token = self.class_embedding.expand(batch_size, -1, -1) # "-1" means to infer the dimension (try this line on its own)
# 14. Create patch embedding (equation 1)
x = self.patch_embedding(x)
# 15. Concat class embedding and patch embedding (equation 1)
x = torch.cat((class_token, x), dim=1)
# 16. Add position embedding to patch embedding (equation 1)
x = self.position_embedding + x
# 17. Run embedding dropout (Appendix B.1)
x = self.embedding_dropout(x)
# 18. Pass patch, position and class embedding through transformer encoder layers (equations 2 & 3)
x = self.transformer_encoder(x)
# 19. Put 0 index logit through classifier (equation 4)
x = self.classifier(x[:, 0]) # run on each sample in a batch at 0 index
return x
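# A minimal, commented-out end-to-end sketch for the ViT class (not part of the original file;
# assumes a 3-class problem and a batch of random 224x224 RGB images, uncomment to run):
# vit = ViT(num_classes=3)
# dummy_batch = torch.randn(8, 3, 224, 224)       # [batch_size, color_channels, img_size, img_size]
# logits = vit(dummy_batch)
# print(logits.shape)                             # torch.Size([8, 3]) -> [batch_size, num_classes]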