22import torch
33import torch .nn as nn
44import torch .optim as optim
5- from torch .autograd import Variable
65from torchvision .datasets .mnist import MNIST
76import torchvision .transforms as transforms
87from torch .utils .data import DataLoader
98import visdom
109
1110viz = visdom .Visdom ()
1211
13- data_train = MNIST ('./pytorch_data /mnist' ,
12+ data_train = MNIST ('./data /mnist' ,
1413 download = True ,
1514 transform = transforms .Compose ([
16- transforms .Scale ((32 , 32 )),
15+ transforms .Resize ((32 , 32 )),
1716 transforms .ToTensor ()]))
18- data_test = MNIST ('./pytorch_data /mnist' ,
17+ data_test = MNIST ('./data /mnist' ,
1918 train = False ,
2019 download = True ,
2120 transform = transforms .Compose ([
22- transforms .Scale ((32 , 32 )),
21+ transforms .Resize ((32 , 32 )),
2322 transforms .ToTensor ()]))
2423data_train_loader = DataLoader (data_train , batch_size = 256 , shuffle = True , num_workers = 8 )
2524data_test_loader = DataLoader (data_test , batch_size = 1024 , num_workers = 8 )
@@ -43,23 +42,21 @@ def train(epoch):
4342 net .train ()
4443 loss_list , batch_list = [], []
4544 for i , (images , labels ) in enumerate (data_train_loader ):
46- images , labels = Variable (images ), Variable (labels )
47-
4845 optimizer .zero_grad ()
4946
5047 output = net (images )
5148
5249 loss = criterion (output , labels )
5350
54- loss_list .append (loss .data [ 0 ] )
51+ loss_list .append (loss .detach (). cpu (). item () )
5552 batch_list .append (i + 1 )
5653
5754 if i % 10 == 0 :
58- print ('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch , i , loss .data [ 0 ] ))
55+ print ('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch , i , loss .detach (). cpu (). item () ))
5956
6057 # Update Visualization
6158 if viz .check_connection ():
62- cur_batch_win = viz .line (torch .FloatTensor (loss_list ), torch .FloatTensor (batch_list ),
59+ cur_batch_win = viz .line (torch .Tensor (loss_list ), torch .Tensor (batch_list ),
6360 win = cur_batch_win , name = 'current_batch_loss' ,
6461 update = (None if cur_batch_win is None else 'replace' ),
6562 opts = cur_batch_win_opts )
@@ -73,14 +70,13 @@ def test():
7370 total_correct = 0
7471 avg_loss = 0.0
7572 for i , (images , labels ) in enumerate (data_test_loader ):
76- images , labels = Variable (images ), Variable (labels )
7773 output = net (images )
7874 avg_loss += criterion (output , labels ).sum ()
79- pred = output .data .max (1 )[1 ]
80- total_correct += pred .eq (labels .data . view_as (pred )).sum ()
75+ pred = output .detach () .max (1 )[1 ]
76+ total_correct += pred .eq (labels .view_as (pred )).sum ()
8177
8278 avg_loss /= len (data_test )
83- print ('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss .data [ 0 ] , float (total_correct ) / len (data_test )))
79+ print ('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss .detach (). cpu (). item () , float (total_correct ) / len (data_test )))
8480
8581
8682def train_and_test (epoch ):
0 commit comments