import matplotlib.pyplot as plt
import torch
from torch.autograd import grad
import torch.nn.functional as F
from torchvision import datasets, transforms
import preconditioned_stochastic_gradient_descent as psgd  # requires the PSGD file

# MNIST pipelines: images arrive as [0, 1]-scaled float tensors.
_to_tensor = transforms.Compose([transforms.ToTensor()])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=_to_tensor),
    batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=_to_tensor),
    batch_size=1000, shuffle=True)

# Input image size for the original LeNet5 is 32x32; MNIST here is 28x28.
def _init_weight(fan_in, fan_out):
    """Return a trainable (fan_in + 1) x fan_out leaf tensor, ~N(0, 0.1^2).

    The linear weights occupy the first fan_in rows; the bias is the last row.
    Built with .requires_grad_() on the fresh tensor instead of
    torch.tensor(existing_tensor, ...), which copies and triggers a
    PyTorch UserWarning.
    """
    return (0.1 * torch.randn(fan_in + 1, fan_out)).requires_grad_()

W1 = _init_weight(1 * 5 * 5, 6)     # conv1: 1 -> 6 channels, 5x5 kernels
W2 = _init_weight(6 * 5 * 5, 16)    # conv2: 6 -> 16 channels, 5x5 kernels
W3 = _init_weight(16 * 4 * 4, 120)  # fc1: feature map here is 4x4, not 5x5
W4 = _init_weight(120, 84)          # fc2
W5 = _init_weight(84, 10)           # output logits, one per digit class
Ws = [W1, W2, W3, W4, W5]

def LeNet5(x):
    """Forward pass of a LeNet5-style net for 28x28 single-channel images.

    Each weight matrix Wk keeps its linear weights in the first rows and
    its bias in the last row. Returns raw class logits.
    """
    h = F.conv2d(x, W1[:-1].view(6, 1, 5, 5), bias=W1[-1])   # 28x28 -> 24x24
    h = F.relu(F.max_pool2d(h, 2))                           # -> 12x12
    h = F.conv2d(h, W2[:-1].view(16, 6, 5, 5), bias=W2[-1])  # -> 8x8
    h = F.relu(F.max_pool2d(h, 2))                           # -> 4x4
    h = h.view(-1, 16 * 4 * 4)                               # flatten
    h = F.relu(torch.addmm(W3[-1], h, W3[:-1]))              # fc1
    h = F.relu(torch.addmm(W4[-1], h, W4[:-1]))              # fc2
    return torch.addmm(W5[-1], h, W5[:-1])                   # logits

def train_loss(data, target):
    """Mean cross-entropy of LeNet5 logits against integer class targets.

    F.cross_entropy fuses the original log_softmax + nll_loss pair into a
    single, numerically equivalent call.
    """
    return F.cross_entropy(LeNet5(data), target)

def test_loss():
    """Classification error rate of the current weights over the test set."""
    wrong = 0
    with torch.no_grad():
        for images, labels in test_loader:
            predictions = LeNet5(images).argmax(dim=1)
            wrong += (predictions != labels).sum().item()
    return wrong / len(test_loader.dataset)

# Kronecker-factored PSGD state: one (left, right) preconditioner pair,
# initialized to identity, per weight matrix.
Qs = [[torch.eye(W.shape[0]), torch.eye(W.shape[1])] for W in Ws]
step_size = 0.002
damping = 0.0005
grad_norm_clip_thr = 1e10  # so large that clipping effectively never fires
TrainLoss, TestLoss = [], []
for epoch in range(10):
    trainloss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        loss = train_loss(data, target)

        grads = grad(loss, Ws)  # , create_graph=True)
        trainloss += loss.item()

        # Random probe vectors. With Hv set to grads (rather than the true
        # Hessian-vector products grad(grads, Ws, v)), PSGD runs in its
        # gradient-whitening mode — TODO confirm against the PSGD module docs.
        v = [torch.randn(W.shape) for W in Ws]
        Hv = grads  # grad(grads, Ws, v)
        with torch.no_grad():
            # Refresh each Kronecker preconditioner from its (v, Hv) pair;
            # damping * dw regularizes the curvature estimate.
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg + damping * dw) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            # Global norm of the preconditioned gradient across all layers.
            grad_norm = torch.sqrt(sum([torch.sum(g * g) for g in pre_grads]))
            # Shrink the step only when the norm exceeds the threshold
            # (the tiny constant guards against division by zero).
            step_adjust = min(grad_norm_clip_thr / (grad_norm + 1.2e-38), 1.0)
            # In-place update under no_grad because Ws are leaf tensors
            # that require grad.
            for i in range(len(Ws)):
                Ws[i] -= step_adjust * step_size * pre_grads[i]

    TrainLoss.append(trainloss / len(train_loader.dataset))
    TestLoss.append(test_loss())
    # Exponential anneal: after the 10 epochs the step size has decayed
    # by a factor of 0.01 overall.
    step_size = 0.01 ** (1 / 9) * step_size
    print('Epoch: {}; train loss: {}; best test loss: {}'.format(epoch, TrainLoss[-1], min(TestLoss)))

# Summary figure: per-epoch training loss (top) and test error rate (bottom).
epochs = range(1, 11)
curves = [
    (TrainLoss, 'Train cross entropy loss'),
    (TestLoss, 'Test classification error rate'),
]
for row, (series, label) in enumerate(curves, start=1):
    plt.subplot(2, 1, row)
    plt.semilogy(epochs, series, '-r', linewidth=0.2)
    plt.xlabel('Epochs')
    plt.ylabel(label)