33import torch .nn .functional as F
44import torch .optim as optim
55from torchvision import datasets , transforms
6- """ I used Thrandis's KFAC from url
7- https://gist.github.com/Thrandis/9b3f75a130ec6c24a64117b7d9304c3f
8- """
96from kfac import KFAC
107
118train_loader = torch .utils .data .DataLoader (
129 datasets .MNIST ('../data' , train = True , download = True ,
1310 transform = transforms .Compose ([
1411 transforms .ToTensor ()])),
15- batch_size = 100 , shuffle = True )
12+ batch_size = 64 , shuffle = True )
1613test_loader = torch .utils .data .DataLoader (
1714 datasets .MNIST ('../data' , train = False , transform = transforms .Compose ([
1815 transforms .ToTensor ()])),
@@ -28,11 +25,11 @@ def __init__(self):
2825 self .fc3 = nn .Linear (84 , 10 )
2926
3027 def forward (self , x ):
31- x = torch . tanh (F .max_pool2d (self .conv1 (x ), 2 ))
32- x = torch . tanh (F .max_pool2d (self .conv2 (x ), 2 ))
28+ x = F . relu (F .max_pool2d (self .conv1 (x ), 2 ))
29+ x = F . relu (F .max_pool2d (self .conv2 (x ), 2 ))
3330 x = x .view (- 1 , 256 )
34- x = torch . tanh (self .fc1 (x ))
35- x = torch . tanh (self .fc2 (x ))
31+ x = F . relu (self .fc1 (x ))
32+ x = F . relu (self .fc2 (x ))
3633 x = self .fc3 (x )
3734 return F .log_softmax (x , dim = 1 )
3835
@@ -50,7 +47,7 @@ def test_loss(model, test_loader):
5047
5148
5249model = LeNet5 ()
53- preconditioner = KFAC (model , 0.001 )
50+ preconditioner = KFAC (model , 0.002 , alpha = 0.05 )
5451lr0 = 0.01
5552optimizer = optim .SGD (model .parameters (), lr = lr0 )
5653TrainLoss , TestLoss = [], []
@@ -61,17 +58,15 @@ def test_loss(model, test_loader):
6158 output = model (data )
6259
6360 loss = F .nll_loss (output , target )
64- for para in model .parameters ():
65- loss += 0.0002 * torch .sum (para * para )
6661
6762 TrainLoss .append (loss .item ())
6863 loss .backward ()
6964 preconditioner .step ()
7065 optimizer .step ()
71- if batch_idx % 10 == 0 :
66+ if batch_idx % 100 == 0 :
7267 print ('Epoch: {}; batch: {}; train loss: {}' .format (epoch , batch_idx , TrainLoss [- 1 ]))
7368
74- lr0 = ( 0.1 ** 0.1 ) * lr0
69+ lr0 = 0.5 * lr0
7570 optimizer .param_groups [0 ]['lr' ] = lr0
7671 TestLoss .append (test_loss (model , test_loader ))
7772 print ('Epoch: {}; best test loss: {}' .format (epoch , min (TestLoss )))
0 commit comments