-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodeling_siglip.py
More file actions
243 lines (197 loc) · 13 KB
/
modeling_siglip.py
File metadata and controls
243 lines (197 loc) · 13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
from typing import Optional, Tuple
import torch
import torch.nn as nn
class SiglipVisionConfig:
    """Plain container for the hyper-parameters of a `SiglipVisionModel`.

    Each keyword argument is stored unchanged on an attribute of the same
    name. Any additional ``**kwargs`` are accepted and silently discarded.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        intermediate_size: int = 3072,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        num_channels: int = 3,
        image_size: int = 224,
        patch_size: int = 16,
        layer_norm_eps: float = 1e-6,
        attention_dropout: float = 0.0,
        num_image_tokens: int = None,
        **kwargs
    ):
        super().__init__()
        # --- Transformer dimensions ---
        # Length of every patch embedding vector (the model width).
        self.hidden_size = hidden_size
        # Hidden width of the feed-forward (MLP) sub-layer in each encoder layer.
        self.intermediate_size = intermediate_size
        # How many encoder layers are stacked.
        self.num_hidden_layers = num_hidden_layers
        # How many heads each self-attention layer splits hidden_size into.
        self.num_attention_heads = num_attention_heads
        # --- Input image geometry ---
        # Channels in the input image (e.g. 3 for RGB, 1 for grayscale).
        self.num_channels = num_channels
        # Height/width of the input image, in pixels.
        self.image_size = image_size
        # Height/width of each non-overlapping patch the image is cut into.
        self.patch_size = patch_size
        # --- Regularisation / numerics ---
        # Dropout probability applied to the attention probabilities.
        self.attention_dropout = attention_dropout
        # Epsilon handed to the LayerNorm layers for numerical stability.
        self.layer_norm_eps = layer_norm_eps
        # Number of image embeddings produced per image; stored as given (None by default).
        self.num_image_tokens = num_image_tokens
class SiglipVisionEmbeddings(nn.Module):
    """Turn a batch of images into a sequence of patch embeddings.

    A convolution whose kernel size equals its stride cuts the image into
    non-overlapping ``patch_size`` x ``patch_size`` patches and projects each
    one to ``hidden_size`` channels; a learned position embedding is then
    added to every patch.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_channels = config.num_channels
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # kernel_size == stride and no padding: each output location covers exactly one patch.
        self.patch_embeddings = nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.num_positions = self.num_patches
        # One learned vector per patch position.
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # [1, num_positions] index tensor (0 .. num_positions-1); non-persistent, so not saved in state_dict.
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: [batch_size, num_channels, height, width].

        Returns:
            [batch_size, num_patches, embed_dim]: patch embeddings plus
            position embeddings.
        """
        # The unpack itself rejects non-4-D input; the sizes are not otherwise used.
        _batch, _channels, _height, _width = pixel_values.shape
        # [batch_size, embed_dim, height // patch_size, width // patch_size]
        patch_grid = self.patch_embeddings(pixel_values)
        # Collapse the patch grid and move channels last -> [batch_size, num_patches, embed_dim]
        patch_tokens = patch_grid.flatten(2).transpose(1, 2)
        # position_ids has a leading axis of 1, so the lookup broadcasts across the batch.
        # NOTE(review): requires num_patches == num_positions, i.e. inputs sized image_size x image_size.
        return patch_tokens + self.position_embedding(self.position_ids)
class SiglipAttention(nn.Module):
    """Multi-head scaled dot-product self-attention ("Attention Is All You Need").

    ``forward`` maps [batch_size, seq_length, hidden_size] to the same shape and
    additionally returns the attention probabilities (after softmax and dropout),
    e.g. for visualization.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # 1/sqrt(head_dim): without it the logits grow with head_dim and push the
        # softmax into a saturated, tiny-gradient regime.
        self.scale = self.head_dim ** -0.5
        self.dropout = config.attention_dropout

        # Projections are created in k, v, q, out order (affects RNG-based init order).
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _split_heads(self, x: torch.Tensor, batch_size: int, seq_length: int) -> torch.Tensor:
        """Reshape [batch_size, seq_length, embed_dim] -> [batch_size, num_heads, seq_length, head_dim]."""
        return x.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply self-attention.

        Args:
            hidden_states: [batch_size, seq_length, embed_dim].

        Returns:
            (output [batch_size, seq_length, embed_dim],
             probabilities [batch_size, num_heads, seq_length, seq_length]).

        Raises:
            ValueError: if an intermediate tensor has an unexpected shape.
        """
        batch_size, seq_length, _ = hidden_states.size()

        queries = self._split_heads(self.q_proj(hidden_states), batch_size, seq_length)
        keys = self._split_heads(self.k_proj(hidden_states), batch_size, seq_length)
        values = self._split_heads(self.v_proj(hidden_states), batch_size, seq_length)

        # scores[b, h, i, j]: how strongly position i attends to position j in head h.
        scores = (queries @ keys.transpose(2, 3)) * self.scale
        expected_scores = (batch_size, self.num_heads, seq_length, seq_length)
        if scores.size() != expected_scores:
            raise ValueError(
                f"Attention weights should be of size {expected_scores}, but is {scores.size()}"
            )

        # Softmax is computed in float32 for stability, then cast back to the activation dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(queries.dtype)
        # Dropout is a no-op unless the module is in training mode.
        probs = nn.functional.dropout(probs, p=self.dropout, training=self.training)

        context = probs @ values
        expected_context = (batch_size, self.num_heads, seq_length, self.head_dim)
        if context.size() != expected_context:
            raise ValueError(
                f"Attention output should be of size {expected_context}, but is {context.size()}"
            )

        # Merge heads: [B, H, L, head_dim] -> [B, L, H, head_dim] -> [B, L, embed_dim]
        merged = context.transpose(1, 2).contiguous().reshape(batch_size, seq_length, self.embed_dim)
        # The output projection mixes information across heads.
        return self.out_proj(merged), probs
class SiglipMLP(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU (tanh approximation) -> Linear.

    Applied independently to every position; expands hidden_size to
    intermediate_size and projects back to hidden_size.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        # hidden_size -> intermediate_size
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        # intermediate_size -> hidden_size
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    # Return annotation must name a real attribute: annotations are evaluated when the
    # `def` runs, so a typo here (e.g. `torch.Tensot`) breaks importing the module.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP.

        Args:
            hidden_states: [..., hidden_size].

        Returns:
            Tensor of the same shape as ``hidden_states``.
        """
        # [..., hidden_size] -> [..., intermediate_size]
        hidden_states = self.fc1(hidden_states)
        # tanh-approximated GELU non-linearity.
        hidden_states = nn.functional.gelu(hidden_states, approximate="tanh")
        # [..., intermediate_size] -> [..., hidden_size]
        hidden_states = self.fc2(hidden_states)
        return hidden_states
class SiglipEncoderLayer(nn.Module):
    """One pre-norm Transformer encoder layer.

    Computes::

        h = x + SelfAttention(LayerNorm1(x))
        y = h + MLP(LayerNorm2(h))

    Input and output are [batch_size, seq_length, hidden_size].
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # Sub-modules are created in the same order as before (affects init RNG order).
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention sub-block; the attention probabilities are discarded here.
        attended, _ = self.self_attn(hidden_states=self.layer_norm1(hidden_states))
        after_attn = attended + hidden_states

        # Position-wise feed-forward sub-block with its own residual connection.
        transformed = self.mlp(self.layer_norm2(after_attn))
        return transformed + after_attn
class SiglipEncoder(nn.Module):
    """A stack of ``config.num_hidden_layers`` encoder layers applied in sequence."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)
        )

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Run every layer in order; the shape is preserved throughout."""
        state = inputs_embeds
        for encoder_layer in self.layers:
            state = encoder_layer(state)
        return state
class SiglipVisionTransformer(nn.Module):
    """SigLIP vision tower: patch embeddings -> Transformer encoder -> final LayerNorm."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        # Patch + position embeddings.
        self.embeddings = SiglipVisionEmbeddings(config)
        # Stack of Transformer encoder layers.
        self.encoder = SiglipEncoder(config)
        # LayerNorm applied to the encoder output.
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    # Returns a single tensor, not a tuple (the previous `-> Tuple` annotation was wrong).
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            pixel_values: [batch_size, num_channels, height, width].

        Returns:
            The normalized last hidden state, [batch_size, num_patches, hidden_size].
        """
        embeddings_output = self.embeddings(pixel_values)
        last_hidden_state = self.encoder(inputs_embeds=embeddings_output)
        last_hidden_state = self.post_layernorm(last_hidden_state)
        return last_hidden_state
class SiglipVisionModel(nn.Module):
    """Top-level SigLIP vision model wrapping a `SiglipVisionTransformer`.

    Subclassing ``nn.Module`` registers ``vision_model`` as a submodule, so
    ``parameters()``, ``state_dict()``, ``.to()``, ``.eval()`` and calling the
    model directly (``model(pixel_values)``) all work.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.vision_model = SiglipVisionTransformer(config)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode images.

        Args:
            pixel_values: [batch_size, num_channels, height, width].

        Returns:
            Last hidden state from the vision transformer,
            [batch_size, num_patches, hidden_size].
        """
        return self.vision_model(pixel_values=pixel_values)