Skip to content

Commit 0a28749

Browse files
committed
Save pytorch introduction tutorial code
1 parent 9f9fcc1 commit 0a28749

File tree

3 files changed

+187
-0
lines changed

3 files changed

+187
-0
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import torch.nn as nn
2+
import torch.nn.functional as F
3+
4+
class Net(nn.Module):
5+
def __init__(self):
6+
super(Net, self).__init__()
7+
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
8+
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
9+
self.conv2_drop = nn.Dropout2d()
10+
self.fc1 = nn.Linear(320, 50)
11+
self.fc2 = nn.Linear(50, 10)
12+
13+
def forward(self, x):
14+
x = F.relu(F.max_pool2d(self.conv1(x), 2))
15+
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
16+
x = x.view(-1, 320)
17+
x = F.relu(self.fc1(x))
18+
x = F.dropout(x, training=self.training)
19+
x = self.fc2(x)
20+
return F.log_softmax(x, dim=1)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import os
2+
import cv2
3+
import torch
4+
import numpy as np
5+
import requests, gzip, os, hashlib
6+
7+
from model import Net
8+
9+
path='Datasets/data' # Path where to save the mnist dataset
10+
def fetch(url):
11+
if os.path.exists(path) is False:
12+
os.makedirs(path)
13+
14+
fp = os.path.join(path, hashlib.md5(url.encode('utf-8')).hexdigest())
15+
if os.path.isfile(fp):
16+
with open(fp, "rb") as f:
17+
data = f.read()
18+
else:
19+
with open(fp, "wb") as f:
20+
data = requests.get(url).content
21+
f.write(data)
22+
return np.frombuffer(gzip.decompress(data), dtype=np.uint8).copy()
23+
24+
test_data = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
25+
test_targets = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
26+
27+
# output path
28+
model_path = 'Model/06_pytorch_introduction'
29+
30+
# construct network and load weights
31+
network = Net()
32+
network.load_state_dict(torch.load("Models/06_pytorch_introduction/model.pt"))
33+
network.eval() # set to evaluation mode
34+
35+
36+
for test_image, test_target in zip(test_data, test_targets):
37+
38+
# convert to tensor
39+
inference_image = torch.from_numpy(test_image).float() / 255.0
40+
inference_image = inference_image.unsqueeze(0).unsqueeze(0)
41+
42+
# predict
43+
output = network(inference_image)
44+
pred = output.argmax(dim=1, keepdim=True)
45+
46+
test_image = cv2.resize(test_image, (400, 400))
47+
cv2.imshow(str(pred.item()), test_image)
48+
key = cv2.waitKey(0)
49+
if key == ord('q'): # break on q key
50+
break
51+
52+
cv2.destroyAllWindows()
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import os
2+
import cv2
3+
import numpy as np
4+
from tqdm import tqdm
5+
import requests, gzip, os, hashlib
6+
7+
import torch
8+
import torch.nn.functional as F
9+
import torch.optim as optim
10+
11+
from model import Net
12+
13+
path='Datasets/data'
14+
def fetch(url):
15+
if os.path.exists(path) is False:
16+
os.makedirs(path)
17+
18+
fp = os.path.join(path, hashlib.md5(url.encode('utf-8')).hexdigest())
19+
if os.path.isfile(fp):
20+
with open(fp, "rb") as f:
21+
data = f.read()
22+
else:
23+
with open(fp, "wb") as f:
24+
data = requests.get(url).content
25+
f.write(data)
26+
return np.frombuffer(gzip.decompress(data), dtype=np.uint8).copy()
27+
28+
# load mnist dataset from yann.lecun.com, train data is of shape (60000, 28, 28) and targets are of shape (60000)
29+
train_data = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
30+
train_targets = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
31+
test_data = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
32+
test_targets = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
33+
34+
# uncomment to show images from dataset using OpenCV
35+
# for train_image, train_target in zip(train_data, train_targets):
36+
# train_image = cv2.resize(train_image, (300, 300))
37+
# cv2.imshow("Image", train_image)
38+
# cv2.waitKey(0)
39+
# cv2.destroyAllWindows()
40+
41+
# define hyperparameters
42+
n_epochs = 5
43+
batch_size_train = 64
44+
batch_size_test = 64
45+
learning_rate = 0.001
46+
47+
# reshape data to (items, channels, height, width) and normalize to [0, 1]
48+
train_data = np.expand_dims(train_data, axis=1) / 255.0
49+
test_data = np.expand_dims(test_data, axis=1) / 255.0
50+
51+
# split data into batches of size [(batch_size, 1, 28, 28) ...]
52+
train_batches = [np.array(train_data[i:i+batch_size_train]) for i in range(0, len(train_data), batch_size_train)]
53+
# split targets into batches of size [(batch_size) ...]
54+
train_target_batches = [np.array(train_targets[i:i+batch_size_train]) for i in range(0, len(train_targets), batch_size_train)]
55+
56+
test_batches = [np.array(test_data[i:i+batch_size_test]) for i in range(0, len(test_data), batch_size_test)]
57+
test_target_batches = [np.array(test_targets[i:i+batch_size_test]) for i in range(0, len(test_targets), batch_size_test)]
58+
59+
# create network and optimizer
60+
network = Net()
61+
optimizer = optim.Adam(network.parameters(), lr=learning_rate)
62+
63+
# create training loop
64+
def train(epoch):
65+
network.train()
66+
67+
loss_sum = 0
68+
train_pbar = tqdm(zip(train_batches, train_target_batches), total=len(train_batches))
69+
for data, target in train_pbar:
70+
71+
# convert data to torch.FloatTensor
72+
data = torch.from_numpy(data).float()
73+
target = torch.from_numpy(target).long()
74+
75+
optimizer.zero_grad()
76+
output = network(data)
77+
loss = F.nll_loss(output, target)
78+
loss.backward()
79+
optimizer.step()
80+
81+
loss_sum += loss.item()
82+
train_pbar.set_description(f"Epoch {epoch}, loss: {loss_sum / len(train_batches):.4f}")
83+
84+
# create testing loop
85+
def test(epoch):
86+
network.eval()
87+
88+
correct = 0
89+
loss_sum = 0
90+
val_pbar = tqdm(zip(test_batches, test_target_batches), total=len(test_batches))
91+
with torch.no_grad():
92+
for data, target in val_pbar:
93+
# convert data to torch.FloatTensor
94+
data = torch.from_numpy(data).float()
95+
target = torch.from_numpy(target).long()
96+
97+
output = network(data)
98+
loss_sum += F.nll_loss(output, target).item() / target.size(0)
99+
pred = output.data.max(1, keepdim=True)[1]
100+
correct += pred.eq(target.data.view_as(pred)).sum() / target.size(0)
101+
102+
val_pbar.set_description(f"val_loss: {loss_sum / len(test_batches):.4f}, val_accuracy: {correct / len(test_batches):.4f}")
103+
104+
# train and test the model
105+
for epoch in range(1, n_epochs + 1):
106+
train(epoch)
107+
test(epoch)
108+
109+
# define output path and create folder if not exists
110+
output_path = 'Models/06_pytorch_introduction'
111+
if not os.path.exists(output_path):
112+
os.makedirs(output_path)
113+
114+
# save model.pt to defined output path
115+
torch.save(network.state_dict(), os.path.join(output_path, "model.pt"))

0 commit comments

Comments
 (0)