-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild_codebook.py
More file actions
38 lines (31 loc) · 1.21 KB
/
build_codebook.py
File metadata and controls
38 lines (31 loc) · 1.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.cluster import KMeans
import numpy as np
import pickle
def extract_all_patches(patch_size=4):
    """Cut every MNIST training image into non-overlapping square patches.

    The MNIST training split is loaded from ``./data`` (with
    ``download=True``) and each image is converted with
    ``transforms.ToTensor()``. Every image is then tiled along both spatial
    axes with window size and stride equal to ``patch_size``, and each tile
    is flattened into one row.

    Args:
        patch_size: Side length of a square patch; also used as the stride,
            so patches do not overlap.

    Returns:
        A 2-D tensor with one flattened patch per row
        (``patch_size * patch_size`` values each). Rows from all images are
        concatenated in dataset order. With the default ``patch_size=4`` a
        28x28 image contributes 49 rows of 16 values.
    """
    mnist_train = datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    patch_dim = patch_size * patch_size

    def to_patch_rows(image):
        # Drop the leading channel axis so only the 2-D image plane remains.
        plane = image.squeeze(0)
        # Tile along rows first, then along columns.
        tiles = plane.unfold(0, patch_size, patch_size)
        tiles = tiles.unfold(1, patch_size, patch_size)
        # unfold returns a strided view; make it contiguous before flattening.
        return tiles.contiguous().view(-1, patch_dim)

    per_image_rows = [to_patch_rows(image) for image, _label in mnist_train]
    return torch.cat(per_image_rows, dim=0)
def build_kmeans_codebook(n_clusters=256, patch_size=4, output_path="codebook.pkl"):
    """Fit a KMeans codebook on MNIST image patches and pickle it to disk.

    Patches are obtained from ``extract_all_patches(patch_size)``, converted
    to a NumPy array, and clustered with ``KMeans``. The fitted estimator is
    then serialized with ``pickle``.

    Args:
        n_clusters: Number of clusters to fit, i.e. the codebook size.
        patch_size: Side length of the square patches, forwarded to
            ``extract_all_patches``.
        output_path: Destination file for the pickled estimator. Defaults to
            ``"codebook.pkl"`` in the current working directory, which was
            previously the only (hard-coded) location.

    Returns:
        The fitted ``KMeans`` estimator, i.e. the same object that was
        written to ``output_path``.
    """
    print("Extracting patches from MNIST...")
    all_patches = extract_all_patches(patch_size)
    print(f"Total patches: {all_patches.shape[0]}")

    # A fixed random_state keeps the fitted codebook reproducible across runs.
    kmeans = KMeans(n_clusters=n_clusters, random_state=0)
    print("Fitting KMeans...")
    kmeans.fit(all_patches.numpy())

    # NOTE: the codebook is stored as a pickle; only unpickle files that come
    # from a trusted source.
    with open(output_path, "wb") as f:
        pickle.dump(kmeans, f)
    print(f"Codebook saved to {output_path}")
    return kmeans
# Script entry point: build and save the codebook using the default settings.
# Nothing runs when this module is merely imported.
if __name__ == "__main__":
    build_kmeans_codebook()