Commit 8f12620

update lr scheduler and precision
1 parent cc0df99 commit 8f12620

2 files changed: +15 -14 lines


ML/Pytorch/Basics/pytorch_lr_ratescheduler.py

Lines changed: 14 additions & 12 deletions
@@ -3,13 +3,12 @@
 case with a (very) small and simple Feedforward Network training on MNIST
 dataset with a learning rate scheduler. In this case ReduceLROnPlateau
 scheduler is used, but can easily be changed to any of the other schedulers
-available.
-
-Video explanation: https://youtu.be/P31hB37g4Ak
-Got any questions leave a comment on youtube :)
+available. I think simply reducing LR by 1/10 or so, when loss plateaus is
+a good default.
 
 Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
 * 2020-04-10 Initial programming
+* 2022-12-19 Updated comments, made sure it works with latest PyTorch
 
 """
 
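The docstring notes that any other scheduler would work just as well; as a hedged illustration (the parameter and optimizer below are toys, not taken from the script), swapping in a cosine schedule mainly changes how step() is called:

import torch

# Illustrative alternative to ReduceLROnPlateau: cosine annealing over 100 epochs.
w = torch.nn.Parameter(torch.zeros(10))
optimizer = torch.optim.SGD([w], lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(100):
    # ... the usual training epoch would run here ...
    scheduler.step()  # unlike ReduceLROnPlateau, no loss value is passed

print(optimizer.param_groups[0]["lr"])  # decayed towards eta_min (0 by default)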

@@ -28,7 +27,9 @@
 
 # Hyperparameters
 num_classes = 10
-learning_rate = 0.1
+learning_rate = (
+    0.1  # way too high learning rate, but we want to see the scheduler in action
+)
 batch_size = 128
 num_epochs = 100
 

@@ -47,7 +48,7 @@
4748

4849
# Define Scheduler
4950
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
50-
optimizer, factor=0.1, patience=5, verbose=True
51+
optimizer, factor=0.1, patience=10, verbose=True
5152
)
5253

5354
# Train Network
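For context on the patience bump from 5 to 10: ReduceLROnPlateau leaves the LR untouched until the metric passed to step() has gone patience consecutive steps without improving, then multiplies it by factor. A minimal self-contained sketch (toy parameter, constant fake loss):

import torch

w = torch.nn.Parameter(torch.zeros(1))  # toy parameter just to build an optimizer
optimizer = torch.optim.SGD([w], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=10
)

for epoch in range(15):
    scheduler.step(1.0)  # the monitored loss never improves
    print(epoch, optimizer.param_groups[0]["lr"])  # 0.1 until patience runs out, then 0.01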
@@ -67,19 +68,19 @@
         losses.append(loss.item())
 
         # backward
+        optimizer.zero_grad()
         loss.backward()
-
-        # gradient descent or adam step
-        # scheduler.step(loss)
         optimizer.step()
-        optimizer.zero_grad()
 
     mean_loss = sum(losses) / len(losses)
+    mean_loss = round(mean_loss, 2)  # we should see difference in loss at 2 decimals
 
     # After each epoch do scheduler.step, note in this scheduler we need to send
-    # in loss for that epoch!
+    # in loss for that epoch! This can also be set using validation loss, and also
+    # in the forward loop we can do on our batch but then we might need to modify
+    # the patience parameter
     scheduler.step(mean_loss)
-    print(f"Cost at epoch {epoch} is {mean_loss}")
+    print(f"Average loss for epoch {epoch} was {mean_loss}")
 
 # Check accuracy on training & test to see how good our model
 def check_accuracy(loader, model):
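The updated comment also points at driving the scheduler with a validation loss instead of the mean training loss; a hedged sketch of that variant (the model, criterion, and validation data below are stand-ins, not the script's):

import torch
import torch.nn as nn

# Stand-ins so the snippet runs on its own; in the script these would be the
# real network, loss, and a held-out validation loader.
model = nn.Linear(784, 10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10)
val_loader = [(torch.randn(128, 784), torch.randint(0, 10, (128,))) for _ in range(4)]

for epoch in range(3):
    # ... the training loop over train_loader would go here ...
    model.eval()
    with torch.no_grad():
        val_losses = [criterion(model(x), y).item() for x, y in val_loader]
    model.train()
    scheduler.step(sum(val_losses) / len(val_losses))  # step on validation loss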
@@ -90,6 +91,7 @@ def check_accuracy(loader, model):
     with torch.no_grad():
         for x, y in loader:
             x = x.to(device=device)
+            x = x.reshape(x.shape[0], -1)
             y = y.to(device=device)
 
             scores = model(x)
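The added reshape flattens each MNIST batch from [N, 1, 28, 28] to [N, 784] so it matches the feedforward network's input size; a quick sanity check:

import torch

x = torch.randn(128, 1, 28, 28)  # shape of a fake MNIST batch
x = x.reshape(x.shape[0], -1)    # flatten everything except the batch dimension
print(x.shape)                   # torch.Size([128, 784])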

ML/Pytorch/Basics/pytorch_mixed_precision_example.py

Lines changed: 1 addition & 2 deletions
@@ -34,7 +34,7 @@ def forward(self, x):
 # Hyperparameters
 in_channel = 1
 num_classes = 10
-learning_rate = 0.001
+learning_rate = 3e-4
 batch_size = 100
 num_epochs = 5
 
@@ -74,7 +74,6 @@ def forward(self, x):
 
 
 # Check accuracy on training & test to see how good our model
-
 def check_accuracy(loader, model):
     num_correct = 0
     num_samples = 0
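This file only changes the learning rate to 3e-4 and drops a blank line; the mixed-precision code itself is not shown in the diff. For reference, a typical PyTorch AMP training step, which the script presumably resembles but may not match exactly, looks roughly like this:

import torch
import torch.nn as nn

# Toy model and batch so the snippet is runnable on its own; the real script
# appears to train a small CNN on MNIST (in_channel=1, num_classes=10),
# which is not reproduced here.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(784, 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(100, 784, device=device)
y = torch.randint(0, 10, (100,), device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=(device == "cuda")):
    loss = criterion(model(x), y)  # forward pass runs in float16 where safe
scaler.scale(loss).backward()      # scale the loss to avoid gradient underflow
scaler.step(optimizer)             # unscales gradients, then calls optimizer.step()
scaler.update()                    # adjust the scale factor for the next step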
