Skip to content

Commit b9d7f39

Browse files
author
Sanyam Kapoor
committed
Upgrade to PyTorch 1.0
1 parent 5cbe5c7 commit b9d7f39

File tree

4 files changed

+16
-22
lines changed

4 files changed

+16
-22
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
data/
22
__pycache__/
3-
pytorch_data/

lenet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ def __init__(self):
3434
('f6', nn.Linear(120, 84)),
3535
('relu6', nn.ReLU()),
3636
('f7', nn.Linear(84, 10)),
37-
('sig7', nn.LogSoftmax())
37+
('sig7', nn.LogSoftmax(dim=-1))
3838
]))
3939

4040
def forward(self, img):
4141
output = self.convnet(img)
42-
output = output.view(-1, 120)
42+
output = output.view(img.size(0), -1)
4343
output = self.fc(output)
4444
return output

requirements.txt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
numpy==1.13.3
2-
# http://download.pytorch.org/whl/torch-0.2.0.post3-cp36-cp36m-macosx_10_7_x86_64.whl
3-
http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
4-
torchvision==0.1.9
5-
visdom==0.1.6.3
1+
numpy~=1.15.0
2+
torch~=1.0.0
3+
torchvision~=0.2.0
4+
visdom~=0.1.0

run.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,23 @@
22
import torch
33
import torch.nn as nn
44
import torch.optim as optim
5-
from torch.autograd import Variable
65
from torchvision.datasets.mnist import MNIST
76
import torchvision.transforms as transforms
87
from torch.utils.data import DataLoader
98
import visdom
109

1110
viz = visdom.Visdom()
1211

13-
data_train = MNIST('./pytorch_data/mnist',
12+
data_train = MNIST('./data/mnist',
1413
download=True,
1514
transform=transforms.Compose([
16-
transforms.Scale((32, 32)),
15+
transforms.Resize((32, 32)),
1716
transforms.ToTensor()]))
18-
data_test = MNIST('./pytorch_data/mnist',
17+
data_test = MNIST('./data/mnist',
1918
train=False,
2019
download=True,
2120
transform=transforms.Compose([
22-
transforms.Scale((32, 32)),
21+
transforms.Resize((32, 32)),
2322
transforms.ToTensor()]))
2423
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
2524
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)
@@ -43,23 +42,21 @@ def train(epoch):
4342
net.train()
4443
loss_list, batch_list = [], []
4544
for i, (images, labels) in enumerate(data_train_loader):
46-
images, labels = Variable(images), Variable(labels)
47-
4845
optimizer.zero_grad()
4946

5047
output = net(images)
5148

5249
loss = criterion(output, labels)
5350

54-
loss_list.append(loss.data[0])
51+
loss_list.append(loss.detach().cpu().item())
5552
batch_list.append(i+1)
5653

5754
if i % 10 == 0:
58-
print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.data[0]))
55+
print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item()))
5956

6057
# Update Visualization
6158
if viz.check_connection():
62-
cur_batch_win = viz.line(torch.FloatTensor(loss_list), torch.FloatTensor(batch_list),
59+
cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
6360
win=cur_batch_win, name='current_batch_loss',
6461
update=(None if cur_batch_win is None else 'replace'),
6562
opts=cur_batch_win_opts)
@@ -73,14 +70,13 @@ def test():
7370
total_correct = 0
7471
avg_loss = 0.0
7572
for i, (images, labels) in enumerate(data_test_loader):
76-
images, labels = Variable(images), Variable(labels)
7773
output = net(images)
7874
avg_loss += criterion(output, labels).sum()
79-
pred = output.data.max(1)[1]
80-
total_correct += pred.eq(labels.data.view_as(pred)).sum()
75+
pred = output.detach().max(1)[1]
76+
total_correct += pred.eq(labels.view_as(pred)).sum()
8177

8278
avg_loss /= len(data_test)
83-
print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.data[0], float(total_correct) / len(data_test)))
79+
print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test)))
8480

8581

8682
def train_and_test(epoch):

0 commit comments

Comments
 (0)