@@ -146,10 +146,64 @@ plt.show()
146
146
147
147
148
148
149
# Create example values
height = 224
width = 224
color_channels = 3
patch_size = 16

# The image has to tile exactly into square patches; otherwise a patch count
# derived from the area would silently ignore the leftover pixels.
if height % patch_size != 0 or width % patch_size != 0:
    raise ValueError(
        f"Image size ({height}x{width}) must be divisible by patch size ({patch_size})"
    )

# Calculate the number of patches (pure integer arithmetic, no float round-trip)
number_of_patches = (height // patch_size) * (width // patch_size)
number_of_patches  # bare expression: displayed as the cell output in a notebook

# Input shape
embedding_layer_input_shape = (height, width, color_channels)

# Output shape: one row per patch, each row holding that patch's flattened values
embedding_layer_output_shape = (number_of_patches, patch_size**2 * color_channels)

print(f"Input shape (single 2D image): {embedding_layer_input_shape}")
print(f"Output shape (single 1D sequence of patches): {embedding_layer_output_shape} -> (number_of_patches, embedding_dimension)")
168
+
169
+
170
+
171
+
172
+
173
# Pick the first entry of `imgs` and move its leading axis to the end so that
# matplotlib can draw it.
# NOTE(review): assumes imgs[0] is laid out (channels, height, width) — confirm
# against the data loader that produced `imgs`.
image = imgs[0]
image_permuted = image.permute(1, 2, 0)

# Show the full image
plt.imshow(image_permuted)

patch_size = 16

# Show only the first `patch_size` rows of the image (the top strip)
plt.figure(figsize=(patch_size, patch_size))
plt.imshow(image_permuted[:patch_size])
187
+
188
+
189
# Setup code to plot top row as patches
img_size = 224
patch_size = 16

# Validate before dividing so a bad configuration fails with a clear message
assert img_size % patch_size == 0, "Image Size must be divisible by patch size"

# Floor division keeps this an int, so it can be used directly as a subplot count
num_patches = img_size // patch_size

num_patches  # bare expression: displayed as the cell output in a notebook

# True division yields a float, floor division an int (notebook display)
img_size / patch_size, img_size // patch_size

# One subplot per patch along the top row of the image
fig, axs = plt.subplots(nrows=1, ncols=num_patches, sharex=True, sharey=True)
for i, patch in enumerate(range(0, img_size, patch_size)):
    # `patch` is the left column offset of the i-th patch in the top row
    axs[i].imshow(image_permuted[:patch_size, patch:patch + patch_size, :])
    axs[i].set_xlabel(i + 1)  # label patches starting from 1
    axs[i].set_xticks([])
    axs[i].set_yticks([])
153
207
154
208
155
209
0 commit comments