@@ -195,18 +195,28 @@ assert img_size % patch_size == 0, "Image Size must be divisible by patch size"
195
195
num_patches
196
196
197
197
198
- img_size / patch_size, img_size // patch_size
199
-
198
+ img_size = 224
199
+ patch_size = 16
200
200
201
- fig, axs = plt.subplots(nrows=1, ncols=img_size // patch_size, sharex=True, sharey=True)
202
- for i, patch in enumerate(range(0, img_size, patch_size)):
203
- axs[i].imshow(image_permuted[:patch_size, patch:patch+patch_size, :]);
204
- axs[i].set_xlabel(i+1)
205
- axs[i].set_xticks([])
206
- axs[i].set_yticks([])
207
201
202
+ fig, axs = plt.subplots(nrows=img_size // patch_size, ncols=img_size // patch_size, figsize=(num_patches, num_patches), sharex=True, sharey=True)
203
+ # Loop through height and width of the image
204
+ for i, patch_size in enumerate(range(0, img_size, 16)):
205
+ for j, patch_width in enumerate(range(0, img_size, 16)):
206
+ # Plot the permuted image on the different axes
207
+ axs[i, j].imshow(image_permuted[patch_height:patch_height+patch_size, patch_width:patch_width+patch_size,:])
208
+ axs[i, j].set_ylabel(i+1, rotation="horizontal", horizontalalignment="right", verticalalignment="center")
209
+ axs[i, j].set_xlabel(j+1)
210
+ axs[i, j].set_xticks([])
211
+ axs[i, j].set_yticks([])
212
+ axs[i, j].label_outer()
213
+
214
+ # Set up a title for the plot
215
+ fig.suptitle(f"{class_names[label]} -> Patchified", fontsize=14)
216
+ plt.show()
208
217
209
218
219
+ patch_size
210
220
211
221
212
222
0 commit comments