Skip to content

Commit f6dcd54

Browse files
authored
Add files via upload
1 parent 07eece5 commit f6dcd54

File tree

2 files changed

+8
-13
lines changed

2 files changed

+8
-13
lines changed

misc/demo_LeNet5_KFAC.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,13 @@
33
import torch.nn.functional as F
44
import torch.optim as optim
55
from torchvision import datasets, transforms
6-
""" I used Thrandis's KFAC from url
7-
https://gist.github.com/Thrandis/9b3f75a130ec6c24a64117b7d9304c3f
8-
"""
96
from kfac import KFAC
107

118
train_loader = torch.utils.data.DataLoader(
129
datasets.MNIST('../data', train=True, download=True,
1310
transform=transforms.Compose([
1411
transforms.ToTensor()])),
15-
batch_size=100, shuffle=True)
12+
batch_size=64, shuffle=True)
1613
test_loader = torch.utils.data.DataLoader(
1714
datasets.MNIST('../data', train=False, transform=transforms.Compose([
1815
transforms.ToTensor()])),
@@ -28,11 +25,11 @@ def __init__(self):
2825
self.fc3 = nn.Linear(84, 10)
2926

3027
def forward(self, x):
31-
x = torch.tanh(F.max_pool2d(self.conv1(x), 2))
32-
x = torch.tanh(F.max_pool2d(self.conv2(x), 2))
28+
x = F.relu(F.max_pool2d(self.conv1(x), 2))
29+
x = F.relu(F.max_pool2d(self.conv2(x), 2))
3330
x = x.view(-1, 256)
34-
x = torch.tanh(self.fc1(x))
35-
x = torch.tanh(self.fc2(x))
31+
x = F.relu(self.fc1(x))
32+
x = F.relu(self.fc2(x))
3633
x = self.fc3(x)
3734
return F.log_softmax(x, dim=1)
3835

@@ -50,7 +47,7 @@ def test_loss(model, test_loader):
5047

5148

5249
model = LeNet5()
53-
preconditioner = KFAC(model, 0.001)
50+
preconditioner = KFAC(model, 0.002, alpha=0.05)
5451
lr0 = 0.01
5552
optimizer = optim.SGD(model.parameters(), lr=lr0)
5653
TrainLoss, TestLoss = [], []
@@ -61,17 +58,15 @@ def test_loss(model, test_loader):
6158
output = model(data)
6259

6360
loss = F.nll_loss(output, target)
64-
for para in model.parameters():
65-
loss += 0.0002*torch.sum(para*para)
6661

6762
TrainLoss.append(loss.item())
6863
loss.backward()
6964
preconditioner.step()
7065
optimizer.step()
71-
if batch_idx % 10 == 0:
66+
if batch_idx % 100 == 0:
7267
print('Epoch: {}; batch: {}; train loss: {}'.format(epoch, batch_idx, TrainLoss[-1]))
7368

74-
lr0 = (0.1**0.1)*lr0
69+
lr0 = 0.5*lr0
7570
optimizer.param_groups[0]['lr'] = lr0
7671
TestLoss.append(test_loss(model, test_loader))
7772
print('Epoch: {}; best test loss: {}'.format(epoch, min(TestLoss)))

misc/mnist_lenet5.jpg

-2.46 KB
Loading

0 commit comments

Comments
 (0)