
Commit 07eece5

Add files via upload
In LeNet5: replace tanh with ReLU; remove weight decay; halve step size every epoch; clip gradient norm
1 parent a9a22ce commit 07eece5

File tree: 1 file changed (+13, -15)

demo_LeNet5.py

Lines changed: 13 additions & 15 deletions
@@ -8,37 +8,35 @@
         datasets.MNIST('../data', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor()])),
-        batch_size=100, shuffle=True)
+        batch_size=64, shuffle=True)
 test_loader = torch.utils.data.DataLoader(
         datasets.MNIST('../data', train=False, transform=transforms.Compose([
                            transforms.ToTensor()])),
         batch_size=1000, shuffle=True)

-"""input image size for the original LeNet5 is 32x32, here is 28x28"""
+"""input image size for the original LeNet5 is 32x32, here is 28x28.
+We also replace tanh with ReLU, remove weight decaying. Simpler yet better"""
 W1 = torch.tensor(0.1*torch.randn(1*5*5+1, 6), requires_grad=True)
 W2 = torch.tensor(0.1*torch.randn(6*5*5+1, 16), requires_grad=True)
-W3 = torch.tensor(0.1*torch.randn(16*4*4+1, 120), requires_grad=True)#so here is 4x4, not 5x5
+W3 = torch.tensor(0.1*torch.randn(16*4*4+1, 120), requires_grad=True)#here is 4x4, not 5x5
 W4 = torch.tensor(0.1*torch.randn(120+1, 84), requires_grad=True)
 W5 = torch.tensor(0.1*torch.randn(84+1, 10), requires_grad=True)
 Ws = [W1, W2, W3, W4, W5]

 def LeNet5(x):
     x = F.conv2d(x, W1[:-1].view(6,1,5,5), bias=W1[-1])
-    x = torch.tanh(F.max_pool2d(x, 2))
+    x = F.relu(F.max_pool2d(x, 2))#replace tanh with ReLU
     x = F.conv2d(x, W2[:-1].view(16,6,5,5), bias=W2[-1])
-    x = torch.tanh(F.max_pool2d(x, 2))
-    x = torch.tanh(x.view(-1, 16*4*4).mm(W3[:-1]) + W3[-1])
-    x = torch.tanh(x.mm(W4[:-1]) + W4[-1])
+    x = F.relu(F.max_pool2d(x, 2))
+    x = F.relu(x.view(-1, 16*4*4).mm(W3[:-1]) + W3[-1])
+    x = F.relu(x.mm(W4[:-1]) + W4[-1])
     y = x.mm(W5[:-1]) + W5[-1]
     return y

 def train_loss(data, target):
     y = LeNet5(data)
     y = F.log_softmax(y, dim=1)
-    loss = F.nll_loss(y, target)
-    for W in Ws:
-        loss += 0.0002*torch.sum(W*W)
-
+    loss = F.nll_loss(y, target)
     return loss

 def test_loss( ):
@@ -53,19 +51,19 @@ def test_loss( ):

 Qs = [[torch.eye(W.shape[0]), torch.eye(W.shape[1])] for W in Ws]
 step_size = 0.1
-grad_norm_clip_thr = 1e10
+grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
 TrainLoss, TestLoss = [], []
 for epoch in range(10):
     for batch_idx, (data, target) in enumerate(train_loader):
         loss = train_loss(data, target)

         grads = grad(loss, Ws, create_graph=True)
         TrainLoss.append(loss.item())
-        if batch_idx%10 == 0:
+        if batch_idx%100==0:
             print('Epoch: {}; batch: {}; train loss: {}'.format(epoch, batch_idx, TrainLoss[-1]))

         v = [torch.randn(W.shape) for W in Ws]
-        Hv = grad(grads, Ws, v)#let Hv=grads if using whitened gradients
+        Hv = grad(grads, Ws, v)#just let Hv=grads if using whitened gradients
         with torch.no_grad():
             Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
             pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
@@ -75,5 +73,5 @@ def test_loss( ):
                 Ws[i] -= step_adjust*step_size*pre_grads[i]

     TestLoss.append(test_loss())
-    step_size = (0.1**0.1)*step_size
+    step_size = 0.5*step_size
     print('Epoch: {}; best test loss: {}'.format(epoch, min(TestLoss)))
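The hunks above skip the lines (old 69-74) where grad_norm_clip_thr and step_adjust are actually computed, so the clipping itself is not visible in this diff. Below is a minimal sketch, in the same bare-PyTorch style, of how such a norm-based clip is typically applied to the preconditioned gradients before the parameter update; the function name clipped_update and the 1e-20 safeguard are illustrative assumptions, not lines from this commit.

import torch

def clipped_update(Ws, pre_grads, step_size, grad_norm_clip_thr):
    # Overall norm of all preconditioned gradients, treated as one long vector.
    grad_norm = torch.sqrt(sum(torch.sum(g*g) for g in pre_grads))
    # Shrink the whole step only when the norm exceeds the threshold (gradient norm clipping).
    step_adjust = min(1.0, grad_norm_clip_thr/(grad_norm.item() + 1e-20))
    with torch.no_grad():
        for i in range(len(Ws)):
            Ws[i] -= step_adjust*step_size*pre_grads[i]

Note that the new threshold, 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5, is 0.1 times the square root of the total parameter count, so the clip now scales with the model size instead of being effectively disabled at 1e10.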
