forked from rosinality/vq-vae-2-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path pixelsnail_mnist.py
More file actions
executable file
· 60 lines (40 loc) · 1.48 KB
/
pixelsnail_mnist.py
File metadata and controls
executable file
· 60 lines (40 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm

from pixelsnail import PixelSNAIL
def train(epoch, loader, model, optimizer, device):
    """Run one training epoch, using each image as its own target.

    The model sees a batch of images and produces per-pixel class logits
    (classes along dim 1). Cross-entropy is computed against the same
    integer images, so the model learns to predict pixel values. A
    progress bar reports the batch loss and per-pixel argmax accuracy.

    Args:
        epoch: Zero-based epoch index; displayed 1-based in the progress bar.
        loader: Iterable of ``(img, label)`` batches; ``label`` is ignored.
            ``img`` must hold integer class indices usable as
            ``nn.CrossEntropyLoss`` targets.
        model: Autoregressive pixel model; may return either the logits
            tensor or a tuple whose first element is the logits.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each image batch is moved to.
    """
    # Make sure dropout etc. are active even if the caller left the model in eval mode.
    model.train()
    criterion = nn.CrossEntropyLoss()
    progress = tqdm(loader)

    for img, _label in progress:
        model.zero_grad()
        img = img.to(device)

        out = model(img)
        # NOTE(review): upstream PixelSNAIL.forward returns (logits, cache);
        # without unpacking, the loss and .max() below fail on a tuple.
        # A plain tensor output is accepted unchanged.
        if isinstance(out, tuple):
            out = out[0]

        loss = criterion(out, img)
        loss.backward()
        optimizer.step()

        # Monitoring only: keep the accuracy computation out of the autograd graph.
        with torch.no_grad():
            pred = out.argmax(1)
            accuracy = (pred == img).float().mean().item()

        progress.set_description(
            f'epoch: {epoch + 1}; loss: {loss.item():.5f}; acc: {accuracy:.5f}'
        )
class PixelTransform:
    """Dataset transform that turns an image into a ``torch.long`` tensor.

    The input is materialized with ``np.array`` and its values are cast to
    64-bit integers unchanged (no scaling or normalization), so the pixel
    values can be used directly as integer class indices.
    """

    def __call__(self, input):
        pixels = np.array(input)
        as_tensor = torch.from_numpy(pixels)
        return as_tensor.to(torch.long)
if __name__ == '__main__':
    # Train PixelSNAIL on MNIST pixel values and checkpoint after every epoch.

    # Prefer the GPU, but fall back to CPU so the script still runs without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epoch = 10

    dataset = datasets.MNIST('.', transform=PixelTransform(), download=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    # 256 output classes — presumably one per 8-bit pixel intensity; confirm
    # against PixelSNAIL's constructor signature.
    model = PixelSNAIL([28, 28], 256, 128, 5, 2, 4, 128)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # torch.save does not create parent directories.
    os.makedirs('checkpoint', exist_ok=True)

    # Honour the `epoch` setting instead of a duplicated literal.
    for i in range(epoch):
        train(i, loader, model, optimizer, device)
        torch.save(model.state_dict(), f'checkpoint/mnist_{str(i + 1).zfill(3)}.pt')