    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()])),
-    batch_size=100, shuffle=True)
+    batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor()])),
    batch_size=1000, shuffle=True)
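# With the ToTensor() transform, each training batch is a float tensor of shape
# [64, 1, 28, 28] with pixel values in [0, 1]; the corresponding labels have shape [64].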

-"""input image size for the original LeNet5 is 32x32, here is 28x28"""
+"""The input image size for the original LeNet5 is 32x32; here it is 28x28.
+We also replace tanh with ReLU and remove the weight decay. Simpler, yet better."""
W1 = torch.tensor(0.1*torch.randn(1*5*5 + 1, 6), requires_grad=True)
W2 = torch.tensor(0.1*torch.randn(6*5*5 + 1, 16), requires_grad=True)
-W3 = torch.tensor(0.1*torch.randn(16*4*4 + 1, 120), requires_grad=True)#so here is 4x4, not 5x5
+W3 = torch.tensor(0.1*torch.randn(16*4*4 + 1, 120), requires_grad=True)#here it is 4x4, not 5x5; see the shape arithmetic below
W4 = torch.tensor(0.1*torch.randn(120 + 1, 84), requires_grad=True)
W5 = torch.tensor(0.1*torch.randn(84 + 1, 10), requires_grad=True)
Ws = [W1, W2, W3, W4, W5]
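# Shape arithmetic behind the 16*4*4 flattening above (a sketch for 28x28 inputs,
# valid 5x5 convolutions and 2x2 max pooling):
#   28 -> conv5 -> 24 -> pool2 -> 12 -> conv5 -> 8 -> pool2 -> 4
# so the second conv stage yields 16 feature maps of size 4x4 (it would be 5x5 for
# the original 32x32 LeNet5 inputs).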

def LeNet5(x):
    x = F.conv2d(x, W1[:-1].view(6, 1, 5, 5), bias=W1[-1])
-    x = torch.tanh(F.max_pool2d(x, 2))
+    x = F.relu(F.max_pool2d(x, 2))#replace tanh with ReLU
    x = F.conv2d(x, W2[:-1].view(16, 6, 5, 5), bias=W2[-1])
-    x = torch.tanh(F.max_pool2d(x, 2))
-    x = torch.tanh(x.view(-1, 16*4*4).mm(W3[:-1]) + W3[-1])
-    x = torch.tanh(x.mm(W4[:-1]) + W4[-1])
+    x = F.relu(F.max_pool2d(x, 2))
+    x = F.relu(x.view(-1, 16*4*4).mm(W3[:-1]) + W3[-1])
+    x = F.relu(x.mm(W4[:-1]) + W4[-1])
    y = x.mm(W5[:-1]) + W5[-1]
    return y
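# A quick illustrative shape check (hypothetical, not part of the original script):
# a dummy batch of two 1x28x28 images should map to two rows of 10 class scores.
assert LeNet5(torch.zeros(2, 1, 28, 28)).shape == (2, 10)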

def train_loss(data, target):
    y = LeNet5(data)
    y = F.log_softmax(y, dim=1)
-    loss = F.nll_loss(y, target)
-    for W in Ws:
-        loss += 0.0002*torch.sum(W*W)
-
+    loss = F.nll_loss(y, target)
    return loss
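# Note: F.log_softmax followed by F.nll_loss is equivalent to F.cross_entropy on the raw
# logits, i.e. F.cross_entropy(LeNet5(data), target) would give the same loss value.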

def test_loss():
@@ -53,19 +51,19 @@ def test_loss():

Qs = [[torch.eye(W.shape[0]), torch.eye(W.shape[1])] for W in Ws]
step_size = 0.1
-grad_norm_clip_thr = 1e10
+grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
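# For the weight shapes above this is 0.1*sqrt(26*6 + 151*16 + 257*120 + 121*84 + 85*10)
# = 0.1*sqrt(44426) ≈ 21.1, i.e. roughly 0.1 times the square root of the total number of weights.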
TrainLoss, TestLoss = [], []
for epoch in range(10):
    for batch_idx, (data, target) in enumerate(train_loader):
        loss = train_loss(data, target)

        grads = grad(loss, Ws, create_graph=True)
        TrainLoss.append(loss.item())
-        if batch_idx % 10 == 0:
+        if batch_idx % 100 == 0:
            print('Epoch: {}; batch: {}; train loss: {}'.format(epoch, batch_idx, TrainLoss[-1]))

        v = [torch.randn(W.shape) for W in Ws]
-        Hv = grad(grads, Ws, v)#let Hv=grads if using whitened gradients
+        Hv = grad(grads, Ws, v)#just let Hv = grads if using whitened gradients
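        # Because grads was built with create_graph=True, grad(grads, Ws, v) is a
        # Hessian-vector product: it returns H @ v for the random probe v (Pearlmutter's
        # trick), which PSGD uses below to fit the Kronecker-factored preconditioners.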
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
@@ -75,5 +73,5 @@ def test_loss():
                Ws[i] -= step_adjust*step_size*pre_grads[i]

    TestLoss.append(test_loss())
-    step_size = (0.1**0.1)*step_size
+    step_size = 0.5*step_size
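    # Halving the step size every epoch means the final (10th) epoch runs with
    # roughly 0.1*0.5**9 ≈ 2e-4.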
    print('Epoch: {}; best test loss: {}'.format(epoch, min(TestLoss)))